Skip to content

Commit 48d96a3

Browse files
committed
change TimeDelay.update(t, data) to .update(data)
1 parent 2f829f1 commit 48d96a3

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

brainpy/math/delayvars.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from brainpy.errors import UnsupportedError
1212
from brainpy.math import numpy_ops as bm
1313
from brainpy.math.ndarray import ndarray, Variable, Array
14-
from brainpy.math.setting import get_dt
14+
from brainpy.math.setting import get_dt, dftype
1515
from brainpy.tools.checking import check_float, check_integer
1616
from brainpy.tools.errors import check_error_in_jit
1717

@@ -72,7 +72,7 @@ class TimeDelay(AbstractDelay):
7272
7373
Parameters
7474
----------
75-
delay_target: Array, ndarray, Variable
75+
delay_target: ArrayType
7676
The initial delay data.
7777
t0: float, int
7878
The zero time.
@@ -139,13 +139,14 @@ def __init__(
139139
# time variables
140140
self.idx = Variable(jnp.asarray([0]))
141141
check_float(t0, 't0', allow_none=False, allow_int=True, )
142-
self.current_time = Variable(jnp.asarray([t0]))
142+
self.current_time = Variable(jnp.asarray([t0], dtype=dftype()))
143143

144144
# delay data
145145
batch_axis = None
146146
if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None):
147147
batch_axis = delay_target.batch_axis + 1
148-
self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype),
148+
self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
149+
dtype=delay_target.dtype),
149150
batch_axis=batch_axis)
150151
if before_t0 is None:
151152
self._before_type = _DATA_BEFORE
@@ -175,13 +176,13 @@ def reset(self,
175176
176177
Parameters
177178
----------
178-
delay_target: Array, ndarray, Variable
179+
delay_target: ArrayType
179180
The delay target.
180181
delay_len: float, int
181182
The maximum delay length. The unit is the time.
182183
t0: int, float
183184
The zero time.
184-
before_t0: int, float, ndarray, Array
185+
before_t0: int, float, ArrayType
185186
The data before t0.
186187
"""
187188
self.delay_len = delay_len
@@ -212,8 +213,12 @@ def __call__(self, time, indices=None):
212213
# check
213214
if check.is_checking():
214215
current_time = self.current_time[0]
215-
check_error_in_jit(time > current_time + 1e-6, self._check_time1, (time, current_time))
216-
check_error_in_jit(time < current_time - self.delay_len - self.dt, self._check_time2, (time, current_time))
216+
check_error_in_jit(time > current_time + 1e-6,
217+
self._check_time1,
218+
(time, current_time))
219+
check_error_in_jit(time < current_time - self.delay_len - self.dt,
220+
self._check_time2,
221+
(time, current_time))
217222
if self._before_type == _FUNC_BEFORE:
218223
res = cond(time < self.t0,
219224
self._before_t0,
@@ -251,9 +256,9 @@ def _false_fn(self, div_mod):
251256
idx %= self.num_delay_step
252257
return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx])
253258

254-
def update(self, time, value):
259+
def update(self, value):
255260
self.data[self.idx[0]] = value
256-
self.current_time[0] = time
261+
self.current_time += self.dt
257262
self.idx.value = (self.idx + 1) % self.num_delay_step
258263

259264

@@ -411,7 +416,7 @@ def retrieve(self, delay_len, *indices):
411416
412417
Parameters
413418
----------
414-
delay_len: int, Array
419+
delay_len: int, ArrayType
415420
The delay length used to retrieve the data.
416421
"""
417422
if check.is_checking():

0 commit comments

Comments
 (0)