44import jax .numpy as jnp
55import numpy as np
66
7- from .ndarray import Array , _as_jax_array_
7+ from .ndarray import Array , _as_jax_array_ , _return , _check_out
88from .compat_numpy import (
99 concatenate , shape
1010)
@@ -86,3 +86,145 @@ def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array:
8686 """
8787 input = _as_jax_array_ (input )
8888 return Array (jnp .expand_dims (input , dim ))
89+
90+
91+ # Math operations
92+ def abs (input : Union [jax .Array , Array ],
93+ * , out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
94+ input = _as_jax_array_ (input )
95+ r = jnp .abs (input )
96+ if out is None :
97+ return _return (r )
98+ else :
99+ _check_out (out )
100+ out .value = r
101+
102+ absolute = abs
103+
104+ def acos (input : Union [jax .Array , Array ],
105+ * , out : Optional [Union [Array ,jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
106+ input = _as_jax_array_ (input )
107+ r = jnp .arccos (input )
108+ if out is None :
109+ return _return (r )
110+ else :
111+ _check_out (out )
112+ out .value = r
113+
114+ arccos = acos
115+
116+ def acosh (input : Union [jax .Array , Array ],
117+ * , out : Optional [Union [Array ,jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
118+ input = _as_jax_array_ (input )
119+ r = jnp .arccosh (input )
120+ if out is None :
121+ return _return (r )
122+ else :
123+ _check_out (out )
124+ out .value = r
125+
126+ arccosh = acosh
127+
128+ def add (input : Union [jax .Array , Array , jnp .number ],
129+ other : Union [jax .Array , Array , jnp .number ],
130+ * , alpha : Optional [jnp .number ] = 1 ,
131+ out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
132+ input = _as_jax_array_ (input )
133+ other = _as_jax_array_ (other )
134+ other = jnp .multiply (alpha , other )
135+ r = jnp .add (input , other )
136+ if out is None :
137+ return _return (r )
138+ else :
139+ _check_out (out )
140+ out .value = r
141+
142+ def addcdiv (input : Union [jax .Array , Array , jnp .number ],
143+ tensor1 : Union [jax .Array , Array , jnp .number ],
144+ tensor2 : Union [jax .Array , Array , jnp .number ],
145+ * , value : jnp .number = 1 ,
146+ out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
147+ tensor1 = _as_jax_array_ (tensor1 )
148+ tensor2 = _as_jax_array_ (tensor2 )
149+ other = jnp .divide (tensor1 , tensor2 )
150+ return add (input , other , alpha = value , out = out )
151+
152+ def addcmul (input : Union [jax .Array , Array , jnp .number ],
153+ tensor1 : Union [jax .Array , Array , jnp .number ],
154+ tensor2 : Union [jax .Array , Array , jnp .number ],
155+ * , value : jnp .number = 1 ,
156+ out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
157+ tensor1 = _as_jax_array_ (tensor1 )
158+ tensor2 = _as_jax_array_ (tensor2 )
159+ other = jnp .multiply (tensor1 , tensor2 )
160+ return add (input , other , alpha = value , out = out )
161+
162+ def angle (input : Union [jax .Array , Array , jnp .number ],
163+ * , out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
164+ input = _as_jax_array_ (input )
165+ r = jnp .angle (input )
166+ if out is None :
167+ return _return (r )
168+ else :
169+ _check_out (out )
170+ out .value = r
171+
172+ def asin (input : Union [jax .Array , Array ],
173+ * , out : Optional [Union [Array ,jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
174+ input = _as_jax_array_ (input )
175+ r = jnp .arcsin (input )
176+ if out is None :
177+ return _return (r )
178+ else :
179+ _check_out (out )
180+ out .value = r
181+
182+ arcsin = asin
183+
184+ def asinh (input : Union [jax .Array , Array ],
185+ * , out : Optional [Union [Array ,jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
186+ input = _as_jax_array_ (input )
187+ r = jnp .arcsinh (input )
188+ if out is None :
189+ return _return (r )
190+ else :
191+ _check_out (out )
192+ out .value = r
193+
194+ arcsinh = asinh
195+
196+ def atan (input : Union [jax .Array , Array ],
197+ * , out : Optional [Union [Array ,jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
198+ input = _as_jax_array_ (input )
199+ r = jnp .arctan (input )
200+ if out is None :
201+ return _return (r )
202+ else :
203+ _check_out (out )
204+ out .value = r
205+
206+ arctan = atan
207+
208+ def atanh (input : Union [jax .Array , Array ],
209+ * , out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
210+ input = _as_jax_array_ (input )
211+ r = jnp .arctanh (input )
212+ if out is None :
213+ return _return (r )
214+ else :
215+ _check_out (out )
216+ out .value = r
217+
218+ arctanh = atanh
219+
220+ def atan2 (input : Union [jax .Array , Array ],
221+ * , out : Optional [Union [Array , jax .Array , np .ndarray ]] = None ) -> Optional [Array ]:
222+ input = _as_jax_array_ (input )
223+ r = jnp .arctan2 (input )
224+ if out is None :
225+ return _return (r )
226+ else :
227+ _check_out (out )
228+ out .value = r
229+
230+ arctan2 = atan2
0 commit comments