Skip to content

Commit 5d93071

Browse files
committed
fix test bugs
1 parent 520ca09 commit 5d93071

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

brainpy/math/object_transform/tests/test_autograd.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,8 @@ def f1(x, y):
563563
class Test(bp.BrainPyObject):
564564
def __init__(self):
565565
super(Test, self).__init__()
566-
self.x = bm.array([1., 2., 3.])
567-
self.y = bm.array([10., 5.])
566+
self.x = bm.Variable(bm.array([1., 2., 3.]))
567+
self.y = bm.Variable(bm.array([10., 5.]))
568568

569569
def __call__(self, ):
570570
a = self.x[0] * self.y[0]
@@ -596,8 +596,8 @@ def f1(x, y):
596596
class Test(bp.BrainPyObject):
597597
def __init__(self):
598598
super(Test, self).__init__()
599-
self.x = bm.array([1., 2., 3.])
600-
self.y = bm.array([10., 5.])
599+
self.x = bm.Variable(bm.array([1., 2., 3.]))
600+
self.y = bm.Variable(bm.array([10., 5.]))
601601

602602
def __call__(self, ):
603603
a = self.x[0] * self.y[0]
@@ -629,7 +629,7 @@ def f1(x, y):
629629
class Test(bp.BrainPyObject):
630630
def __init__(self):
631631
super(Test, self).__init__()
632-
self.x = bm.array([1., 2., 3.])
632+
self.x = bm.Variable(bm.array([1., 2., 3.]))
633633

634634
def __call__(self, y):
635635
a = self.x[0] * y[0]
@@ -663,7 +663,7 @@ def f1(x, y):
663663
class Test(bp.BrainPyObject):
664664
def __init__(self):
665665
super(Test, self).__init__()
666-
self.x = bm.array([1., 2., 3.])
666+
self.x = bm.Variable(bm.array([1., 2., 3.]))
667667

668668
def __call__(self, y):
669669
a = self.x[0] * y[0]
@@ -698,7 +698,7 @@ def f1(x, y):
698698
class Test(bp.BrainPyObject):
699699
def __init__(self):
700700
super(Test, self).__init__()
701-
self.x = bm.array([1., 2., 3.])
701+
self.x = bm.Variable(bm.array([1., 2., 3.]))
702702

703703
def __call__(self, y):
704704
a = self.x[0] * y[0]
@@ -737,7 +737,7 @@ def f1(x, y):
737737
class Test(bp.BrainPyObject):
738738
def __init__(self):
739739
super(Test, self).__init__()
740-
self.x = bm.array([1., 2., 3.])
740+
self.x = bm.Variable(bm.array([1., 2., 3.]))
741741

742742
def __call__(self, y):
743743
a = self.x[0] * y[0]
@@ -779,7 +779,7 @@ def f1(x, y):
779779
class Test(bp.BrainPyObject):
780780
def __init__(self):
781781
super(Test, self).__init__()
782-
self.x = bm.array([1., 2., 3.])
782+
self.x = bm.Variable(bm.array([1., 2., 3.]))
783783

784784
def __call__(self, y):
785785
a = self.x[0] * y[0]
@@ -819,7 +819,7 @@ def f1(x, y):
819819
class Test(bp.BrainPyObject):
820820
def __init__(self):
821821
super(Test, self).__init__()
822-
self.x = bm.array([1., 2., 3.])
822+
self.x = bm.Variable(bm.array([1., 2., 3.]))
823823

824824
def __call__(self, y):
825825
a = self.x[0] * y[0]
@@ -968,8 +968,8 @@ def test1(self):
968968
class Test(bp.BrainPyObject):
969969
def __init__(self):
970970
super(Test, self).__init__()
971-
self.x = bm.ones(5)
972-
self.y = bm.ones(5)
971+
self.x = bm.Variable(bm.ones(5))
972+
self.y = bm.Variable(bm.ones(5))
973973

974974
def __call__(self, *args, **kwargs):
975975
return self.x ** 2 + self.y ** 2 + 10

brainpy/math/operators/tests/test_op_register.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,6 @@ def test_op(self):
135135
ax = fig.add_subplot(gs[0, 2])
136136
bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)
137137

138+
# clear
138139
bm.clear_buffer_memory()
139140
plt.close()

brainpy/math/random.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import jax
99
import numpy as np
1010
from jax import lax, jit, vmap, numpy as jnp, random as jr, core
11-
from jax import dtypes
11+
from jax._src import dtypes
1212
from jax.experimental.host_callback import call
1313
from jax.tree_util import register_pytree_node
1414

@@ -75,10 +75,12 @@ def _remove_jax_array(a):
7575

7676

7777
def _const(example, val):
78-
dtype = dtypes.dtype(example, canonicalize=True)
7978
if dtypes.is_python_scalar(example):
79+
dtype = dtypes.canonicalize_dtype(type(example))
8080
val = dtypes.scalar_type_of(example)(val)
8181
return val if dtype == dtypes.dtype(val, canonicalize=True) else np.array(val, dtype)
82+
else:
83+
dtype = dtypes.canonicalize_dtype(example.dtype)
8284
return np.array(val, dtype)
8385

8486

brainpy/math/tests/test_random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_randn(self):
6262
def test_random1(self):
6363
br.seed()
6464
a = br.random()
65-
self.assertIsInstance(a, bm.jaxarray.Array)
65+
self.assertIsInstance(a, bm.Array)
6666
self.assertTrue(0. <= a < 1)
6767

6868
def test_random2(self):

0 commit comments

Comments
 (0)