Skip to content

Commit 23ae81c

Browse files
authored
fix #149 (dozens of random samplings in NumPy) and fix JaxArray op errors (#216)
fix #149 (dozens of random samplings in NumPy) and fix JaxArray op errors
2 parents e2f5170 + 184d720 commit 23ae81c

File tree

7 files changed

+1218
-338
lines changed

7 files changed

+1218
-338
lines changed

.github/workflows/Windows_CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
3131
python -m pip install numpy==1.21.0
32-
python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
32+
python -m pip install "jax[cpu]==0.3.2" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
3333
python -m pip install -r requirements-win.txt
3434
python -m pip install tqdm brainpylib
3535
python setup.py install

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
publishment.md
22
#experimental/
33
.vscode
4-
4+
io_test_tmp*
55

66
brainpy/base/tests/io_test_tmp*
77

brainpy/math/jaxarray.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def __sub__(self, oc):
236236
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))
237237

238238
def __rsub__(self, oc):
239-
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))
239+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) - self._value)
240240

241241
def __isub__(self, oc):
242242
# a -= b
@@ -249,7 +249,7 @@ def __mul__(self, oc):
249249
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))
250250

251251
def __rmul__(self, oc):
252-
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))
252+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) * self._value)
253253

254254
def __imul__(self, oc):
255255
# a *= b
@@ -258,17 +258,17 @@ def __imul__(self, oc):
258258
self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc)
259259
return self
260260

261-
def __div__(self, oc):
262-
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
261+
# def __div__(self, oc):
262+
# return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
263263

264264
def __rdiv__(self, oc):
265-
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
265+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)
266266

267267
def __truediv__(self, oc):
268268
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
269269

270270
def __rtruediv__(self, oc):
271-
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
271+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)
272272

273273
def __itruediv__(self, oc):
274274
# a /= b
@@ -281,7 +281,7 @@ def __floordiv__(self, oc):
281281
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))
282282

283283
def __rfloordiv__(self, oc):
284-
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))
284+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) // self._value)
285285

286286
def __ifloordiv__(self, oc):
287287
# a //= b
@@ -291,16 +291,16 @@ def __ifloordiv__(self, oc):
291291
return self
292292

293293
def __divmod__(self, oc):
294-
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
294+
return JaxArray(self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc))
295295

296296
def __rdivmod__(self, oc):
297-
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
297+
return JaxArray(self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc))
298298

299299
def __mod__(self, oc):
300300
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
301301

302302
def __rmod__(self, oc):
303-
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
303+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) % self._value)
304304

305305
def __imod__(self, oc):
306306
# a %= b
@@ -313,7 +313,7 @@ def __pow__(self, oc):
313313
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))
314314

315315
def __rpow__(self, oc):
316-
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))
316+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ** self._value)
317317

318318
def __ipow__(self, oc):
319319
# a **= b
@@ -326,7 +326,7 @@ def __matmul__(self, oc):
326326
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))
327327

328328
def __rmatmul__(self, oc):
329-
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))
329+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) @ self._value)
330330

331331
def __imatmul__(self, oc):
332332
# a @= b
@@ -339,7 +339,7 @@ def __and__(self, oc):
339339
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))
340340

341341
def __rand__(self, oc):
342-
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))
342+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) & self._value)
343343

344344
def __iand__(self, oc):
345345
# a &= b
@@ -352,7 +352,7 @@ def __or__(self, oc):
352352
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))
353353

354354
def __ror__(self, oc):
355-
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))
355+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) | self._value)
356356

357357
def __ior__(self, oc):
358358
# a |= b
@@ -365,7 +365,7 @@ def __xor__(self, oc):
365365
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))
366366

367367
def __rxor__(self, oc):
368-
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))
368+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ^ self._value)
369369

370370
def __ixor__(self, oc):
371371
# a ^= b
@@ -378,7 +378,7 @@ def __lshift__(self, oc):
378378
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))
379379

380380
def __rlshift__(self, oc):
381-
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))
381+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) << self._value)
382382

383383
def __ilshift__(self, oc):
384384
# a <<= b
@@ -391,7 +391,7 @@ def __rshift__(self, oc):
391391
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))
392392

393393
def __rrshift__(self, oc):
394-
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))
394+
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) >> self._value)
395395

396396
def __irshift__(self, oc):
397397
# a >>= b

0 commit comments

Comments
 (0)