Skip to content

Commit 02ace16

Browse files
committed
fix bugs
1 parent 4dcddce commit 02ace16

File tree

3 files changed

+10
-20
lines changed

3 files changed

+10
-20
lines changed

brainpy/_src/dyn/tests/test_dyn_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class TestDSRunner(unittest.TestCase):
1010
def test1(self):
11-
class ExampleDS(bp.dyn.DynamicalSystem):
11+
class ExampleDS(bp.DynamicalSystem):
1212
def __init__(self):
1313
super(ExampleDS, self).__init__()
1414
self.i = bm.Variable(bm.zeros(1))
@@ -17,23 +17,23 @@ def update(self, tdi):
1717
self.i += 1
1818

1919
ds = ExampleDS()
20-
runner = bp.dyn.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False)
20+
runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False)
2121
runner.run(100.)
2222

2323
def test_t_and_dt(self):
24-
class ExampleDS(bp.dyn.DynamicalSystem):
24+
class ExampleDS(bp.DynamicalSystem):
2525
def __init__(self):
2626
super(ExampleDS, self).__init__()
2727
self.i = bm.Variable(bm.zeros(1))
2828

2929
def update(self, tdi):
3030
self.i += 1 * tdi.dt
3131

32-
runner = bp.dyn.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False)
32+
runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False)
3333
runner.run(100.)
3434

3535
def test_DSView(self):
36-
class EINet(bp.dyn.Network):
36+
class EINet(bp.Network):
3737
def __init__(self, scale=1.0, method='exp_auto'):
3838
super(EINet, self).__init__()
3939

brainpy/_src/math/ndarray.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ class VariableView(Variable):
11141114
"""
11151115

11161116
def __init__(self, value: Variable, index):
1117-
self.index = index
1117+
self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, Array))
11181118
if not isinstance(value, Variable):
11191119
raise ValueError('Must be instance of Variable.')
11201120
super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis)
@@ -1137,6 +1137,10 @@ def __repr__(self) -> str:
11371137
def value(self):
11381138
return self._value[self.index]
11391139

1140+
@value.setter
1141+
def value(self, v):
1142+
self.update(v)
1143+
11401144
def update(self, value):
11411145
int_shape = self.shape
11421146
if self.batch_axis is None:

brainpy/_src/math/object_transform/tests/test_jit.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,6 @@
77

88

99
class TestJaxArrayJIT(unittest.TestCase):
10-
def test_jaxarray_outside_jit1(self):
11-
class SomeProgram(bp.BrainPyObject):
12-
def __init__(self):
13-
super(SomeProgram, self).__init__()
14-
self.a = bm.zeros(2)
15-
self.b = bm.Variable(bm.ones(2))
16-
17-
def __call__(self, *args, **kwargs):
18-
self.a[0] += 1
19-
self.b[0] += 1
20-
21-
run = bm.jit(SomeProgram())
22-
with self.assertRaises(bp.errors.MathError):
23-
run()
2410

2511
def test_jaxarray_inside_jit1(self):
2612
bp.math.random.seed()

0 commit comments

Comments
 (0)