Skip to content

Commit 536c3a8

Browse files
committed
fix bugs in LengthDelay and VariableView
1 parent 4ca22f6 commit 536c3a8

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

brainpy/math/delayvars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
435435
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
436436

437437
elif self.update_method == CONCAT_UPDATING:
438-
self.data.value = bm.concatenate([self.data[1:], bm.broadcast_to(value, self.delay_target_shape)], axis=0)
438+
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])
439439

440440
else:
441441
raise ValueError(f'Unknown updating method "{self.update_method}"')

brainpy/math/jaxarray.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,3 +1673,23 @@ def update(self, value):
16731673
f"while we got {value.dtype}.")
16741674
self._value[self.index] = value.value if isinstance(value, JaxArray) else value
16751675

1676+
@value.setter
1677+
def value(self, value):
1678+
int_shape = self.shape
1679+
if self.batch_axis is None:
1680+
ext_shape = value.shape
1681+
else:
1682+
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
1683+
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
1684+
if ext_shape != int_shape:
1685+
error = f"The shape of the original data is {int_shape}, while we got {value.shape}"
1686+
if self.batch_axis is None:
1687+
error += '. Do you forget to set "batch_axis" when initialize this variable?'
1688+
else:
1689+
error += f' with batch_axis={self.batch_axis}.'
1690+
raise MathError(error)
1691+
if value.dtype != self._value.dtype:
1692+
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
1693+
f"while we got {value.dtype}.")
1694+
self._value[self.index] = value.value if isinstance(value, JaxArray) else value
1695+

brainpy/math/tests/test_jaxarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,36 @@ def test_variable_init(self):
5858
not bm.array_equal(bm.Variable(bm.random.rand(10)),
5959
bm.Variable(10))
6060
)
61+
62+
63+
class TestVariableView(unittest.TestCase):
64+
def test_update(self):
65+
origin = bm.Variable(bm.zeros(10))
66+
view = bm.VariableView(origin, slice(0, 5, None))
67+
68+
view.update(bm.ones(5))
69+
self.assertTrue(
70+
bm.array_equal(origin, bm.concatenate([bm.ones(5), bm.zeros(5)]))
71+
)
72+
73+
view.value = bm.arange(5.)
74+
self.assertTrue(
75+
bm.array_equal(origin, bm.concatenate([bm.arange(5), bm.zeros(5)]))
76+
)
77+
78+
view += 10
79+
self.assertTrue(
80+
bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)]))
81+
)
82+
83+
bm.random.shuffle(view)
84+
print(view)
85+
print(origin)
86+
87+
view.sort()
88+
self.assertTrue(
89+
bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)]))
90+
)
91+
92+
self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10))
93+

0 commit comments

Comments
 (0)