Skip to content

Commit 697129e

Browse files
committed
speedup delay retrieval by reversing delay variable data
1 parent 34b18cd commit 697129e

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

brainpy/math/delayvars.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,24 @@ class LengthDelay(AbstractDelay):
280280
It can also be arrays. Or a callable function or instance of ``Connector``.
281281
Note that ``initial_delay_data`` should be arranged as the following way::
282282
283-
delay = delay_len [ data
284-
delay = delay_len-1 data
283+
delay = 1 [ data
284+
delay = 2 data
285285
... ....
286286
... ....
287-
delay = 2 data
288-
delay = 1 data ]
287+
delay = delay_len-1 data
288+
delay = delay_len data ]
289+
290+
.. versionchanged:: 2.2.3.2
291+
292+
The data in the previous version of ``LengthDelay`` is::
293+
294+
delay = delay_len [ data
295+
delay = delay_len-1 data
296+
... ....
297+
... ....
298+
delay = 2 data
299+
delay = 1 data ]
300+
289301
290302
name: str
291303
The delay object name.
@@ -368,13 +380,13 @@ def reset(
368380
dtype=delay_target.dtype)
369381

370382
# update delay data
371-
self.data[-1] = delay_target
383+
self.data[0] = delay_target
372384
if initial_delay_data is None:
373385
pass
374386
elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
375-
self.data[:-1] = initial_delay_data
387+
self.data[1:] = initial_delay_data
376388
elif callable(initial_delay_data):
377-
self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape,
389+
self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
378390
dtype=delay_target.dtype)
379391
else:
380392
raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')
@@ -406,20 +418,22 @@ def retrieve(self, delay_len, *indices):
406418
check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len)
407419

408420
if self.update_method == ROTATION_UPDATING:
409-
# the delay length
410-
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
421+
delay_idx = (self.idx[0] + delay_len) % self.num_delay_step
411422
delay_idx = stop_gradient(delay_idx)
412-
if not jnp.issubdtype(delay_idx.dtype, jnp.integer):
413-
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
414423

415424
elif self.update_method == CONCAT_UPDATING:
416-
delay_idx = self.num_delay_step - 1 - delay_len
425+
delay_idx = delay_len
417426

418427
else:
419428
raise ValueError(f'Unknown updating method "{self.update_method}"')
420429

421-
# the delay data
430+
# the delay index
431+
if isinstance(delay_idx, int):
432+
pass
433+
elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
434+
raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
422435
indices = (delay_idx,) + tuple(indices)
436+
# the delay data
423437
return self.data[indices]
424438

425439
def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
@@ -435,7 +449,10 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
435449
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
436450

437451
elif self.update_method == CONCAT_UPDATING:
438-
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])
452+
if self.num_delay_step >= 2:
453+
self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]])
454+
else:
455+
self.data[:] = value
439456

440457
else:
441458
raise ValueError(f'Unknown updating method "{self.update_method}"')

brainpy/math/tests/test_delay_vars.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def test2(self):
9494
dim = 3
9595
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
9696
delay = bm.LengthDelay(jnp.zeros(dim), 10,
97-
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
97+
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
98+
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
9899
update_method=update_method)
99100
print(delay(0))
100101
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
@@ -111,7 +112,8 @@ def test3(self):
111112
dim = 3
112113
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
113114
delay = bm.LengthDelay(jnp.zeros(dim), 10,
114-
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
115+
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
116+
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
115117
update_method=update_method)
116118
print(delay(jnp.asarray([1, 2, 3]),
117119
jnp.arange(3)))

0 commit comments

Comments
 (0)