@@ -236,7 +236,7 @@ def __sub__(self, oc):
236236 return JaxArray (self ._value - (oc ._value if isinstance (oc , JaxArray ) else oc ))
237237
238238 def __rsub__ (self , oc ):
239- return JaxArray (self . _value - (oc ._value if isinstance (oc , JaxArray ) else oc ))
239+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) - self . _value )
240240
241241 def __isub__ (self , oc ):
242242 # a -= b
@@ -249,7 +249,7 @@ def __mul__(self, oc):
249249 return JaxArray (self ._value * (oc ._value if isinstance (oc , JaxArray ) else oc ))
250250
251251 def __rmul__ (self , oc ):
252- return JaxArray (self . _value * (oc ._value if isinstance (oc , JaxArray ) else oc ))
252+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) * self . _value )
253253
254254 def __imul__ (self , oc ):
255255 # a *= b
@@ -258,17 +258,17 @@ def __imul__(self, oc):
258258 self ._value = self ._value * (oc ._value if isinstance (oc , JaxArray ) else oc )
259259 return self
260260
261- def __div__ (self , oc ):
262- return JaxArray (self ._value / (oc ._value if isinstance (oc , JaxArray ) else oc ))
261+ # def __div__(self, oc):
262+ # return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
263263
264264 def __rdiv__ (self , oc ):
265- return JaxArray (self . _value / (oc ._value if isinstance (oc , JaxArray ) else oc ))
265+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) / self . _value )
266266
267267 def __truediv__ (self , oc ):
268268 return JaxArray (self ._value / (oc ._value if isinstance (oc , JaxArray ) else oc ))
269269
270270 def __rtruediv__ (self , oc ):
271- return JaxArray (self . _value / (oc ._value if isinstance (oc , JaxArray ) else oc ))
271+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) / self . _value )
272272
273273 def __itruediv__ (self , oc ):
274274 # a /= b
@@ -281,7 +281,7 @@ def __floordiv__(self, oc):
281281 return JaxArray (self ._value // (oc ._value if isinstance (oc , JaxArray ) else oc ))
282282
283283 def __rfloordiv__ (self , oc ):
284- return JaxArray (self . _value // (oc ._value if isinstance (oc , JaxArray ) else oc ))
284+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) // self . _value )
285285
286286 def __ifloordiv__ (self , oc ):
287287 # a //= b
@@ -291,16 +291,16 @@ def __ifloordiv__(self, oc):
291291 return self
292292
293293 def __divmod__ (self , oc ):
294- return JaxArray (self ._value % (oc ._value if isinstance (oc , JaxArray ) else oc ))
294+ return JaxArray (self ._value . __divmod__ (oc ._value if isinstance (oc , JaxArray ) else oc ))
295295
296296 def __rdivmod__ (self , oc ):
297- return JaxArray (self ._value % (oc ._value if isinstance (oc , JaxArray ) else oc ))
297+ return JaxArray (self ._value . __rdivmod__ (oc ._value if isinstance (oc , JaxArray ) else oc ))
298298
299299 def __mod__ (self , oc ):
300300 return JaxArray (self ._value % (oc ._value if isinstance (oc , JaxArray ) else oc ))
301301
302302 def __rmod__ (self , oc ):
303- return JaxArray (self . _value % (oc ._value if isinstance (oc , JaxArray ) else oc ))
303+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) % self . _value )
304304
305305 def __imod__ (self , oc ):
306306 # a %= b
@@ -313,7 +313,7 @@ def __pow__(self, oc):
313313 return JaxArray (self ._value ** (oc ._value if isinstance (oc , JaxArray ) else oc ))
314314
315315 def __rpow__ (self , oc ):
316- return JaxArray (self . _value ** (oc ._value if isinstance (oc , JaxArray ) else oc ))
316+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) ** self . _value )
317317
318318 def __ipow__ (self , oc ):
319319 # a **= b
@@ -326,7 +326,7 @@ def __matmul__(self, oc):
326326 return JaxArray (self ._value @ (oc ._value if isinstance (oc , JaxArray ) else oc ))
327327
328328 def __rmatmul__ (self , oc ):
329- return JaxArray (self . _value @ (oc ._value if isinstance (oc , JaxArray ) else oc ))
329+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) @ self . _value )
330330
331331 def __imatmul__ (self , oc ):
332332 # a @= b
@@ -339,7 +339,7 @@ def __and__(self, oc):
339339 return JaxArray (self ._value & (oc ._value if isinstance (oc , JaxArray ) else oc ))
340340
341341 def __rand__ (self , oc ):
342- return JaxArray (self . _value & (oc ._value if isinstance (oc , JaxArray ) else oc ))
342+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) & self . _value )
343343
344344 def __iand__ (self , oc ):
345345 # a &= b
@@ -352,7 +352,7 @@ def __or__(self, oc):
352352 return JaxArray (self ._value | (oc ._value if isinstance (oc , JaxArray ) else oc ))
353353
354354 def __ror__ (self , oc ):
355- return JaxArray (self . _value | (oc ._value if isinstance (oc , JaxArray ) else oc ))
355+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) | self . _value )
356356
357357 def __ior__ (self , oc ):
358358 # a |= b
@@ -365,7 +365,7 @@ def __xor__(self, oc):
365365 return JaxArray (self ._value ^ (oc ._value if isinstance (oc , JaxArray ) else oc ))
366366
367367 def __rxor__ (self , oc ):
368- return JaxArray (self . _value ^ (oc ._value if isinstance (oc , JaxArray ) else oc ))
368+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) ^ self . _value )
369369
370370 def __ixor__ (self , oc ):
371371 # a ^= b
@@ -378,7 +378,7 @@ def __lshift__(self, oc):
378378 return JaxArray (self ._value << (oc ._value if isinstance (oc , JaxArray ) else oc ))
379379
380380 def __rlshift__ (self , oc ):
381- return JaxArray (self . _value << (oc ._value if isinstance (oc , JaxArray ) else oc ))
381+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) << self . _value )
382382
383383 def __ilshift__ (self , oc ):
384384 # a <<= b
@@ -391,7 +391,7 @@ def __rshift__(self, oc):
391391 return JaxArray (self ._value >> (oc ._value if isinstance (oc , JaxArray ) else oc ))
392392
393393 def __rrshift__ (self , oc ):
394- return JaxArray (self . _value >> (oc ._value if isinstance (oc , JaxArray ) else oc ))
394+ return JaxArray ((oc ._value if isinstance (oc , JaxArray ) else oc ) >> self . _value )
395395
396396 def __irshift__ (self , oc ):
397397 # a >>= b
0 commit comments