Skip to content

Commit e5a17c8

Browse files
authored
speedup delay retrieval by reversing delay variable data (#279)
speedup delay retrieval by reversing delay variable data
2 parents 7693f14 + 697129e commit e5a17c8

File tree

5 files changed

+38
-22
lines changed

5 files changed

+38
-22
lines changed

.github/workflows/Linux_CI.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ jobs:
2828
run: |
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
31-
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
3231
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
3332
python setup.py install
3433
- name: Lint with flake8
3534
run: |
3635
# stop the build if there are Python syntax errors or undefined names
3736
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
3837
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
39-
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4039
- name: Test with pytest
4140
run: |
4241
pytest brainpy/

.github/workflows/MacOS_CI.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,14 @@ jobs:
2828
run: |
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
31-
python -m pip install jax==0.3.14
32-
python -m pip install jaxlib==0.3.14
3331
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
3432
python setup.py install
3533
- name: Lint with flake8
3634
run: |
3735
# stop the build if there are Python syntax errors or undefined names
3836
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
3937
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
40-
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4139
- name: Test with pytest
4240
run: |
4341
pytest brainpy/

.github/workflows/Windows_CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
# stop the build if there are Python syntax errors or undefined names
4040
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
4141
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
42-
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
42+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4343
- name: Test with pytest
4444
run: |
4545
pytest brainpy/

brainpy/math/delayvars.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,24 @@ class LengthDelay(AbstractDelay):
280280
It can also be arrays. Or a callable function or instance of ``Connector``.
281281
Note that ``initial_delay_data`` should be arranged as the following way::
282282
283-
delay = delay_len [ data
284-
delay = delay_len-1 data
283+
delay = 1 [ data
284+
delay = 2 data
285285
... ....
286286
... ....
287-
delay = 2 data
288-
delay = 1 data ]
287+
delay = delay_len-1 data
288+
delay = delay_len data ]
289+
290+
.. versionchanged:: 2.2.3.2
291+
292+
The data in the previous version of ``LengthDelay`` is::
293+
294+
delay = delay_len [ data
295+
delay = delay_len-1 data
296+
... ....
297+
... ....
298+
delay = 2 data
299+
delay = 1 data ]
300+
289301
290302
name: str
291303
The delay object name.
@@ -368,13 +380,13 @@ def reset(
368380
dtype=delay_target.dtype)
369381

370382
# update delay data
371-
self.data[-1] = delay_target
383+
self.data[0] = delay_target
372384
if initial_delay_data is None:
373385
pass
374386
elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
375-
self.data[:-1] = initial_delay_data
387+
self.data[1:] = initial_delay_data
376388
elif callable(initial_delay_data):
377-
self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape,
389+
self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
378390
dtype=delay_target.dtype)
379391
else:
380392
raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')
@@ -406,20 +418,22 @@ def retrieve(self, delay_len, *indices):
406418
check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len)
407419

408420
if self.update_method == ROTATION_UPDATING:
409-
# the delay length
410-
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
421+
delay_idx = (self.idx[0] + delay_len) % self.num_delay_step
411422
delay_idx = stop_gradient(delay_idx)
412-
if not jnp.issubdtype(delay_idx.dtype, jnp.integer):
413-
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
414423

415424
elif self.update_method == CONCAT_UPDATING:
416-
delay_idx = self.num_delay_step - 1 - delay_len
425+
delay_idx = delay_len
417426

418427
else:
419428
raise ValueError(f'Unknown updating method "{self.update_method}"')
420429

421-
# the delay data
430+
# the delay index
431+
if isinstance(delay_idx, int):
432+
pass
433+
elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
434+
raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
422435
indices = (delay_idx,) + tuple(indices)
436+
# the delay data
423437
return self.data[indices]
424438

425439
def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
@@ -435,7 +449,10 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
435449
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
436450

437451
elif self.update_method == CONCAT_UPDATING:
438-
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])
452+
if self.num_delay_step >= 2:
453+
self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]])
454+
else:
455+
self.data[:] = value
439456

440457
else:
441458
raise ValueError(f'Unknown updating method "{self.update_method}"')

brainpy/math/tests/test_delay_vars.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def test2(self):
9494
dim = 3
9595
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
9696
delay = bm.LengthDelay(jnp.zeros(dim), 10,
97-
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
97+
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
98+
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
9899
update_method=update_method)
99100
print(delay(0))
100101
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
@@ -111,7 +112,8 @@ def test3(self):
111112
dim = 3
112113
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
113114
delay = bm.LengthDelay(jnp.zeros(dim), 10,
114-
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
115+
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
116+
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
115117
update_method=update_method)
116118
print(delay(jnp.asarray([1, 2, 3]),
117119
jnp.arange(3)))

0 commit comments

Comments
 (0)