
Commit b5215cf

rename 'FixedLenDelay' to 'TimeDelay'
1 parent: d1665f4

7 files changed: +75 / -53 lines changed

7 files changed

+75
-53
lines changed

README.md

Lines changed: 7 additions & 5 deletions
@@ -147,16 +147,18 @@ runner.run(100.)
 
 
 
-Numerical methods for delay differential equations (SDEs).
+Numerical methods for delay differential equations (SDEs).
 
 ```python
-xdelay = bm.FixedLenDelay(1, delay_len=1., before_t0=1., dt=0.01)
+xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)
+
 
 @bp.ddeint(method='rk4', state_delays={'x': xdelay})
 def second_order_eq(x, y, t):
-  dx = y
-  dy = -y - 2*x - 0.5*xdelay(t-1)
-  return dx, dy
+  dx = y
+  dy = -y - 2 * x - 0.5 * xdelay(t - 1)
+  return dx, dy
+
 
 runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
 runner.run(100.)
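For reference, here is the updated README example assembled into one runnable snippet. The two import lines are assumptions (the hunk does not show them); everything else is taken from the diff.

```python
import brainpy as bp        # assumed import, not shown in the hunk
import brainpy.math as bm   # assumed import, not shown in the hunk

# a second-order DDE whose state 'x' is read back 1 time unit in the past
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)

@bp.ddeint(method='rk4', state_delays={'x': xdelay})
def second_order_eq(x, y, t):
  dx = y
  dy = -y - 2 * x - 0.5 * xdelay(t - 1)
  return dx, dy

runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
runner.run(100.)
```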

brainpy/datasets/chaotic_systems.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
   assert isinstance(inits, (bm.ndarray, jnp.ndarray))
 
   rng = bm.random.RandomState(seed)
-  xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt)
+  xdelay = bm.TimeDelay(inits.shape, tau, dt=dt)
   xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)
 
   @ddeint(method=method, state_delays={'x': xdelay})
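A standalone sketch of the seeding step this hunk touches: the Mackey-Glass generator builds a `TimeDelay` over the history window [t - tau, t] and fills its buffer with a jittered constant. Only attributes visible in the diff (`data`, `num_delay_step`) are used; the concrete values below are illustrative, not from the source.

```python
import brainpy.math as bm

tau, dt = 2., 0.1
inits = bm.ones(1) * 1.2          # hypothetical initial state
rng = bm.random.RandomState(42)

xdelay = bm.TimeDelay(inits.shape, tau, dt=dt)
# one buffered row per step of the history window, jittered around `inits`
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)
print(xdelay.num_delay_step)      # number of stored time steps
```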

brainpy/dyn/neurons/rate_models.py

Lines changed: 1 addition & 1 deletion
@@ -211,7 +211,7 @@ def __init__(self,
     # variables
     self.w = bm.Variable(bm.zeros(self.num))
     self.V = bm.Variable(bm.zeros(self.num))
-    self.Vdelay = bm.FixedLenDelay(self.num, self.delay, interp_method='round')
+    self.Vdelay = bm.TimeDelay(self.num, self.delay, interp_method='round')
     self.input = bm.Variable(bm.zeros(self.num))
 
     # integral
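A sketch (not from the diff) of the pattern this rate model uses: a `TimeDelay` buffer holds the membrane potential, is read `delay` ms in the past with 'round' interpolation, and is pushed forward with `update()` every step. Sizes and dynamics here are made up for illustration.

```python
import brainpy.math as bm

num, delay, dt = 4, 5., 0.1                  # hypothetical population size and delay (ms)
Vdelay = bm.TimeDelay(num, delay, dt=dt, interp_method='round')

t, V = 0., bm.zeros(num)
for _ in range(100):
  V_delayed = Vdelay(t - delay)              # membrane potential `delay` ms ago
  V = V + dt * (-V + 0.5 * V_delayed)        # toy dynamics, illustrative only
  t += dt
  Vdelay.update(t, V)                        # push the new V into the buffer
```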

brainpy/integrators/dde/base.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ def __init__(
       dt: Union[float, int] = None,
       name: str = None,
       show_code: bool = False,
-      state_delays: Dict[str, bm.FixedLenDelay] = None,
+      state_delays: Dict[str, bm.TimeDelay] = None,
       neutral_delays: Dict[str, bm.NeutralDelay] = None,
   ):
     dt = bm.get_dt() if dt is None else dt
@@ -59,7 +59,7 @@ def __init__(
     # delays
     self._state_delays = dict()
     if state_delays is not None:
-      check_dict_data(state_delays, key_type=str, val_type=bm.FixedLenDelay)
+      check_dict_data(state_delays, key_type=str, val_type=bm.TimeDelay)
       for key, delay in state_delays.items():
         if key not in self.variables:
           raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}')
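A minimal sketch of the validation this hunk changes: `state_delays` must map variable names to `bm.TimeDelay` instances. The `check_dict_data` call is copied from the hunk; its import path is an assumption, since the hunk does not show the module's imports.

```python
import brainpy.math as bm
from brainpy.tools.checking import check_dict_data  # assumed location of the helper

xdelay = bm.TimeDelay(1, delay_len=1., dt=0.1, before_t0=1.)
state_delays = {'x': xdelay}

# passes: keys are str, values are TimeDelay instances (objects built through the
# deprecated FixedLenDelay wrapper also pass, because it now returns a TimeDelay)
check_dict_data(state_delays, key_type=str, val_type=bm.TimeDelay)
```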

brainpy/integrators/runner.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ class IntegratorRunner(Runner):
  >>> dt = 0.01; beta=2.; gamma=1.; tau=2.; n=9.65
  >>> mg_eq = lambda x, t, xdelay: (beta * xdelay(t - tau) / (1 + xdelay(t - tau) ** n)
  >>>                               - gamma * x)
- >>> xdelay = bm.FixedLenDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2)
+ >>> xdelay = bm.TimeDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2)
  >>> integral = bp.ddeint(mg_eq, method='rk4', state_delays={'x': xdelay})
  >>> runner = bp.integrators.IntegratorRunner(
  >>>   integral,

brainpy/math/delay_vars.py

Lines changed: 53 additions & 33 deletions
@@ -1,21 +1,24 @@
 # -*- coding: utf-8 -*-
 
-
+import warnings
 from typing import Union, Callable, Tuple
 
 import jax.numpy as jnp
+import numpy as np
 from jax import vmap
 from jax.experimental.host_callback import id_tap
 from jax.lax import cond
 
+from brainpy import check
 from brainpy import math as bm
 from brainpy.base.base import Base
+from brainpy.errors import UnsupportedError
 from brainpy.tools.checking import check_float
 from brainpy.tools.others import to_size
-from brainpy.errors import UnsupportedError
 
 __all__ = [
   'AbstractDelay',
+  'TimeDelay',
   'FixedLenDelay',
   'NeutralDelay',
 ]
@@ -32,35 +35,35 @@ def update(self, time, value):
 _INTERP_ROUND = 'round'
 
 
-class FixedLenDelay(AbstractDelay):
-  """Delay variable which has a fixed delay length.
+class TimeDelay(AbstractDelay):
+  """Delay variable which has a fixed delay time length.
 
   For example, we create a delay variable which has a maximum delay length of 1 ms
 
   >>> import brainpy.math as bm
-  >>> delay = bm.FixedLenDelay(bm.zeros(3), delay_len=1., dt=0.1)
+  >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1)
   >>> delay(-0.5)
   [-0. -0. -0.]
 
   This function supports multiple dimensions of the tensor. For example,
 
   1. the one-dimensional delay data
 
-  >>> delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
+  >>> delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
   >>> delay(-0.2)
   [-0.2 -0.2 -0.2]
 
   2. the two-dimensional delay data
 
-  >>> delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
+  >>> delay = bm.TimeDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
   >>> delay(-0.6)
   [[-0.6 -0.6]
    [-0.6 -0.6]
   [-0.6 -0.6]]
 
   3. the three-dimensional delay data
 
-  >>> delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
+  >>> delay = bm.TimeDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
   >>> delay(-0.6)
   [[[-0.8]
     [-0.8]]
@@ -113,7 +116,7 @@ def __init__(
       dtype=None,
       interp_method='linear_interp',
   ):
-    super(FixedLenDelay, self).__init__(name=name)
+    super(TimeDelay, self).__init__(name=name)
 
     # shape
     self.shape = to_size(shape)
@@ -161,6 +164,10 @@ def __init__(
     else:
       raise ValueError(f'"before_t0" does not support {type(before_t0)}: before_t0')
 
+    self.f = jnp.interp
+    for dim in range(1, len(self.shape) + 1, 1):
+      self.f = vmap(self.f, in_axes=(None, None, dim), out_axes=dim - 1)
+
   @property
   def idx(self):
     return self._idx
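The interpolation function `self.f` built here used to be recreated inside `_false_fn` on every lookup (see the removal in a later hunk); it is now constructed once per instance. A standalone sketch of what the nested `vmap` over `jnp.interp` computes, with a hypothetical shape:

```python
import jax
import jax.numpy as jnp

shape = (3, 2)                    # hypothetical delay-variable shape
f = jnp.interp
for dim in range(1, len(shape) + 1, 1):
  # map over each data axis of the stored values; the query point and the
  # two grid times are shared (in_axes=None) across all components
  f = jax.vmap(f, in_axes=(None, None, dim), out_axes=dim - 1)

xp = jnp.asarray([0., 0.1])       # the two bracketing time steps
fp = jnp.stack([jnp.zeros(shape), jnp.ones(shape)])   # values at those steps, shape (2, 3, 2)
print(f(0.05, xp, fp))            # componentwise 0.5, shape (3, 2)
```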
@@ -191,36 +198,37 @@ def current_time(self):
 
   def _check_time(self, times, transforms):
     prev_time, current_time = times
-    current_time = bm.as_device_array(current_time)
-    prev_time = bm.as_device_array(prev_time)
+    current_time = np.asarray(current_time, dtype=bm.float_)
+    prev_time = np.asarray(prev_time, dtype=bm.float_)
     if prev_time > current_time:
       raise ValueError(f'\n'
                        f'!!! Error in {self.__class__.__name__}: \n'
                        f'The request time should be less than the '
                        f'current time {current_time}. But we '
                        f'got {prev_time} > {current_time}')
-    lower_time = jnp.asarray(current_time - self.delay_len)
+    lower_time = np.asarray(current_time - self.delay_len)
     if prev_time < lower_time:
       raise ValueError(f'\n'
                        f'!!! Error in {self.__class__.__name__}: \n'
                        f'The request time of the variable should be in '
                        f'[{lower_time}, {current_time}], but we got {prev_time}')
 
-  def __call__(self, prev_time):
+  def __call__(self, time, indices=None):
     # check
-    id_tap(self._check_time, (prev_time, self.current_time))
+    if check.is_checking():
+      id_tap(self._check_time, (time, self.current_time))
     if self._before_type == _FUNC_BEFORE:
-      return cond(prev_time < self.t0,
+      return cond(time < self.t0,
                   self._before_t0,
                   self._after_t0,
-                  prev_time)
+                  time)
     else:
-      return self._after_t0(prev_time)
+      return self._after_t0(time)
 
   def _after_t0(self, prev_time):
     diff = self.delay_len - (self.current_time - prev_time)
-    if isinstance(diff, bm.ndarray): diff = diff.value
-
+    if isinstance(diff, bm.ndarray):
+      diff = diff.value
     if self.interp_method == _INTERP_LINEAR:
       req_num_step = jnp.asarray(diff / self._dt, dtype=bm.get_dint())
       extra = diff - req_num_step * self._dt
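A standalone sketch of the gating pattern this hunk introduces in `__call__`: under `jit` the requested time is a tracer, so the bounds check has to run on the host via `id_tap`, and it is now only staged when BrainPy's global checking switch (`check.is_checking()`) is on. The `CHECKING` flag below stands in for that switch; everything else mirrors the calls visible in the hunk.

```python
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_tap

def _check_time(times, transforms):
  # runs on the host with concrete values, so it is free to raise
  prev_time, current_time = times
  if prev_time > current_time:
    raise ValueError(f'requested time {prev_time} is after current time {current_time}')

CHECKING = True  # stand-in for brainpy.check.is_checking()

@jax.jit
def read(prev_time, current_time):
  if CHECKING:                                      # evaluated at trace time: the tap is
    id_tap(_check_time, (prev_time, current_time))  # only inserted when checking is on
  return prev_time

read(jnp.asarray(0.5), jnp.asarray(1.0))            # fine; swapping the arguments would trigger the check
```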
@@ -238,31 +246,43 @@ def _true_fn(self, div_mod):
 
   def _false_fn(self, div_mod):
     req_num_step, extra = div_mod
-    f = jnp.interp
-    for dim in range(1, len(self.shape) + 1, 1):
-      f = vmap(f, in_axes=(None, None, dim), out_axes=dim - 1)
     idx = jnp.asarray([self.idx[0] + req_num_step,
                        self.idx[0] + req_num_step + 1])
     idx %= self.num_delay_step
-    return f(extra, jnp.asarray([0., self._dt]), self._data[idx])
+    return self.f(extra, jnp.asarray([0., self._dt]), self._data[idx])
 
   def update(self, time, value):
     self._data[self._idx[0]] = value
     self._current_time[0] = time
     self._idx.value = (self._idx + 1) % self.num_delay_step
 
 
-class VariedLenDelay(AbstractDelay):
-  """Delay variable which has a functional delay
-
-  """
+def FixedLenDelay(shape: Union[int, Tuple[int, ...]],
+                  delay_len: Union[float, int],
+                  before_t0: Union[Callable, bm.ndarray, jnp.ndarray, float, int] = None,
+                  t0: Union[float, int] = 0.,
+                  dt: Union[float, int] = None,
+                  name: str = None,
+                  dtype=None,
+                  interp_method='linear_interp', ):
+  warnings.warn('Please use "brainpy.math.TimeDelay" instead. '
+                '"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ',
+                DeprecationWarning)
+  return TimeDelay(shape=shape,
+                   delay_len=delay_len,
+                   before_t0=before_t0,
+                   t0=t0,
+                   dt=dt,
+                   name=name,
+                   dtype=dtype,
+                   interp_method=interp_method)
+
+
+class NeutralDelay(TimeDelay):
+  pass
 
-  def update(self, time, value):
-    pass
 
-  def __init__(self):
-    super(VariedLenDelay, self).__init__()
+class LengthDelay(AbstractDelay):
+  pass
 
 
-class NeutralDelay(FixedLenDelay):
-  pass
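A short sketch of what the new `FixedLenDelay` wrapper means for existing code: the old name still constructs a delay, but it now emits a `DeprecationWarning` and returns a `TimeDelay` instance.

```python
import warnings
import brainpy.math as bm

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  xdelay = bm.FixedLenDelay(1, delay_len=1., before_t0=1., dt=0.1)  # old spelling

assert isinstance(xdelay, bm.TimeDelay)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# new spelling: identical arguments, no warning
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.1)
```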

brainpy/math/tests/test_delay_vars.py

Lines changed: 10 additions & 10 deletions
@@ -12,7 +12,7 @@ def test_dim1(self):
     # linear interp
     t0 = 0.
     before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
-    delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
+    delay = bm.TimeDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
     self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10))
     self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 9.5))
     print()
@@ -21,8 +21,8 @@ def test_dim1(self):
     # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones(10) * 8.7))
 
     # round interp
-    delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0,
-                             interp_method='round')
+    delay = bm.TimeDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0,
+                         interp_method='round')
     self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10))
     self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 10))
     self.assertTrue(bm.array_equal(delay(t0 - 0.2), bm.ones(10) * 9))
@@ -31,7 +31,7 @@ def test_dim2(self):
     t0 = 0.
     before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
     before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
-    delay = bm.FixedLenDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
+    delay = bm.TimeDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
     self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 10))
     self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 9.5))
     # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5)) * 8.7))
@@ -41,27 +41,27 @@ def test_dim3(self):
     before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
     before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
     before_t0 = bm.repeat(before_t0.reshape((11, 10, 5, 1)), 3, axis=3)
-    delay = bm.FixedLenDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
+    delay = bm.TimeDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0)
     self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 10))
     self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 9.5))
     # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 8.7))
 
   def test1(self):
     print()
-    delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
+    delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
     print(delay(-0.2))
-    delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
+    delay = bm.TimeDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
     print(delay(-0.6))
-    delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
+    delay = bm.TimeDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
     print(delay(-0.8))
 
   def test_current_time2(self):
     print()
-    delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
+    delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t)
     print(delay(0.))
     before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)
     before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2)
-    delay = bm.FixedLenDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0)
+    delay = bm.TimeDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0)
     print(delay(0.))
 
   # def test_prev_time_beyond_boundary(self):
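The assertions above pin down the difference between the two interpolation modes; a condensed sketch using the same dt and history buffer as the tests:

```python
import brainpy.math as bm

# history: row k holds the value k, one row per 0.1 time step over delay_len=1.
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1)

linear = bm.TimeDelay(10, delay_len=1., t0=0., dt=0.1, before_t0=before_t0)
print(linear(-0.15))   # linear interpolation between steps -> 9.5 (as the test asserts)

rounded = bm.TimeDelay(10, delay_len=1., t0=0., dt=0.1, before_t0=before_t0,
                       interp_method='round')
print(rounded(-0.15))  # snapped to a buffered step -> 10. (as the test asserts)
```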
