Skip to content

Commit e0a3ee1

Browse files
authored
Merge pull request #340 from Breeze-P/master
feat: pytorch math-operations-a
2 parents 176565f + d7e504a commit e0a3ee1

File tree

3 files changed

+193
-2
lines changed

3 files changed

+193
-2
lines changed

brainpy/_src/math/compat_pytorch.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import jax.numpy as jnp
55
import numpy as np
66

7-
from .ndarray import Array, _as_jax_array_
7+
from .ndarray import Array, _as_jax_array_, _return, _check_out
88
from .compat_numpy import (
99
concatenate, shape
1010
)
@@ -86,3 +86,145 @@ def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array:
8686
"""
8787
input = _as_jax_array_(input)
8888
return Array(jnp.expand_dims(input, dim))
89+
90+
91+
# Math operations
92+
def abs(input: Union[jax.Array, Array],
93+
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
94+
input = _as_jax_array_(input)
95+
r = jnp.abs(input)
96+
if out is None:
97+
return _return(r)
98+
else:
99+
_check_out(out)
100+
out.value = r
101+
102+
absolute = abs
103+
104+
def acos(input: Union[jax.Array, Array],
105+
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
106+
input = _as_jax_array_(input)
107+
r = jnp.arccos(input)
108+
if out is None:
109+
return _return(r)
110+
else:
111+
_check_out(out)
112+
out.value = r
113+
114+
arccos = acos
115+
116+
def acosh(input: Union[jax.Array, Array],
117+
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
118+
input = _as_jax_array_(input)
119+
r = jnp.arccosh(input)
120+
if out is None:
121+
return _return(r)
122+
else:
123+
_check_out(out)
124+
out.value = r
125+
126+
arccosh = acosh
127+
128+
def add(input: Union[jax.Array, Array, jnp.number],
129+
other: Union[jax.Array, Array, jnp.number],
130+
*, alpha: Optional[jnp.number] = 1,
131+
out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
132+
input = _as_jax_array_(input)
133+
other = _as_jax_array_(other)
134+
other = jnp.multiply(alpha, other)
135+
r = jnp.add(input, other)
136+
if out is None:
137+
return _return(r)
138+
else:
139+
_check_out(out)
140+
out.value = r
141+
142+
def addcdiv(input: Union[jax.Array, Array, jnp.number],
143+
tensor1: Union[jax.Array, Array, jnp.number],
144+
tensor2: Union[jax.Array, Array, jnp.number],
145+
*, value: jnp.number = 1,
146+
out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
147+
tensor1 = _as_jax_array_(tensor1)
148+
tensor2 = _as_jax_array_(tensor2)
149+
other = jnp.divide(tensor1, tensor2)
150+
return add(input, other, alpha=value, out=out)
151+
152+
def addcmul(input: Union[jax.Array, Array, jnp.number],
153+
tensor1: Union[jax.Array, Array, jnp.number],
154+
tensor2: Union[jax.Array, Array, jnp.number],
155+
*, value: jnp.number = 1,
156+
out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
157+
tensor1 = _as_jax_array_(tensor1)
158+
tensor2 = _as_jax_array_(tensor2)
159+
other = jnp.multiply(tensor1, tensor2)
160+
return add(input, other, alpha=value, out=out)
161+
162+
def angle(input: Union[jax.Array, Array, jnp.number],
163+
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
164+
input = _as_jax_array_(input)
165+
r = jnp.angle(input)
166+
if out is None:
167+
return _return(r)
168+
else:
169+
_check_out(out)
170+
out.value = r
171+
172+
def asin(input: Union[jax.Array, Array],
173+
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
174+
input = _as_jax_array_(input)
175+
r = jnp.arcsin(input)
176+
if out is None:
177+
return _return(r)
178+
else:
179+
_check_out(out)
180+
out.value = r
181+
182+
arcsin = asin
183+
184+
def asinh(input: Union[jax.Array, Array],
185+
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
186+
input = _as_jax_array_(input)
187+
r = jnp.arcsinh(input)
188+
if out is None:
189+
return _return(r)
190+
else:
191+
_check_out(out)
192+
out.value = r
193+
194+
arcsinh = asinh
195+
196+
def atan(input: Union[jax.Array, Array],
197+
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
198+
input = _as_jax_array_(input)
199+
r = jnp.arctan(input)
200+
if out is None:
201+
return _return(r)
202+
else:
203+
_check_out(out)
204+
out.value = r
205+
206+
arctan = atan
207+
208+
def atanh(input: Union[jax.Array, Array],
209+
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
210+
input = _as_jax_array_(input)
211+
r = jnp.arctanh(input)
212+
if out is None:
213+
return _return(r)
214+
else:
215+
_check_out(out)
216+
out.value = r
217+
218+
arctanh = atanh
219+
220+
def atan2(input: Union[jax.Array, Array],
221+
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
222+
input = _as_jax_array_(input)
223+
r = jnp.arctan2(input)
224+
if out is None:
225+
return _return(r)
226+
else:
227+
_check_out(out)
228+
out.value = r
229+
230+
arctan2 = atan2

brainpy/_src/math/tests/test_compat_pytorch.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88
import brainpy.math as bm
99
from brainpy._src.math import compat_pytorch
10+
import brainpy._src.math.compat_pytorch as torch
1011

1112
from absl .testing import parameterized
1213

@@ -45,3 +46,33 @@ def test1(self):
4546
a = a.expand(1, 6, 4, -1)
4647
self.assertTrue(a.shape == (1, 6, 4, 5))
4748

49+
class TestMathOperators(unittest.TestCase):
50+
def test_abs(self):
51+
arr = compat_pytorch.Tensor([-1, -2, 3])
52+
a = compat_pytorch.abs(arr)
53+
res = compat_pytorch.Tensor([1, 2, 3])
54+
b = compat_pytorch.absolute(arr)
55+
self.assertTrue(bm.array_equal(a, res))
56+
self.assertTrue(bm.array_equal(b, res))
57+
58+
def test_add(self):
59+
a = compat_pytorch.Tensor([0.0202, 1.0985, 1.3506, -0.6056])
60+
a = compat_pytorch.add(a, 20)
61+
res = compat_pytorch.Tensor([20.0202, 21.0985, 21.3506, 19.3944])
62+
self.assertTrue(bm.array_equal(a, res))
63+
b = compat_pytorch.Tensor([-0.9732, -0.3497, 0.6245, 0.4022])
64+
c = compat_pytorch.Tensor([[0.3743], [-1.7724], [-0.5811], [-0.8017]])
65+
b = compat_pytorch.add(b, c, alpha=10)
66+
self.assertTrue(b.shape == (4, 4))
67+
print("b:", b)
68+
69+
def test_addcdiv(self):
70+
rng = bm.random.default_rng(999)
71+
t = rng.rand(1, 3)
72+
t1 = rng.randn(3, 1)
73+
rng = bm.random.default_rng(199)
74+
t2 = rng.randn(1, 3)
75+
res = torch.addcdiv(t, t1, t2, value=0.1)
76+
print("t + t1/t2 * value:", res)
77+
res = torch.addcmul(t, t1, t2, value=0.1)
78+
print("t + t1*t2 * value:", res)

brainpy/math/compat_pytorch.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,23 @@
44

55
flatten as flatten,
66
cat as cat,
7-
7+
unsqueeze as unsqueeze,
8+
abs as abs,
9+
absolute as absolute,
10+
acos as acos,
11+
arccos as arccos,
12+
acosh as acosh,
13+
arccosh as arccosh,
14+
add as add,
15+
addcdiv as addcdiv,
16+
addcmul as addcmul,
17+
angle as angle,
18+
asin as asin,
19+
arcsin as arcsin,
20+
asinh as asinh,
21+
arcsin as arcsin,
22+
atan as atan,
23+
arctan as arctan,
24+
atan2 as atan2,
25+
atanh as atanh,
826
)

0 commit comments

Comments
 (0)