|
11 | 11 | from brainpy.errors import UnsupportedError |
12 | 12 | from brainpy.math import numpy_ops as bm |
13 | 13 | from brainpy.math.ndarray import ndarray, Variable, Array |
14 | | -from brainpy.math.setting import get_dt |
| 14 | +from brainpy.math.setting import get_dt, dftype |
15 | 15 | from brainpy.tools.checking import check_float, check_integer |
16 | 16 | from brainpy.tools.errors import check_error_in_jit |
17 | 17 |
|
@@ -72,7 +72,7 @@ class TimeDelay(AbstractDelay): |
72 | 72 |
|
73 | 73 | Parameters |
74 | 74 | ---------- |
75 | | - delay_target: Array, ndarray, Variable |
| 75 | + delay_target: ArrayType |
76 | 76 | The initial delay data. |
77 | 77 | t0: float, int |
78 | 78 | The zero time. |
@@ -139,13 +139,14 @@ def __init__( |
139 | 139 | # time variables |
140 | 140 | self.idx = Variable(jnp.asarray([0])) |
141 | 141 | check_float(t0, 't0', allow_none=False, allow_int=True, ) |
142 | | - self.current_time = Variable(jnp.asarray([t0])) |
| 142 | + self.current_time = Variable(jnp.asarray([t0], dtype=dftype())) |
143 | 143 |
|
144 | 144 | # delay data |
145 | 145 | batch_axis = None |
146 | 146 | if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): |
147 | 147 | batch_axis = delay_target.batch_axis + 1 |
148 | | - self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype), |
| 148 | + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, |
| 149 | + dtype=delay_target.dtype), |
149 | 150 | batch_axis=batch_axis) |
150 | 151 | if before_t0 is None: |
151 | 152 | self._before_type = _DATA_BEFORE |
@@ -175,13 +176,13 @@ def reset(self, |
175 | 176 |
|
176 | 177 | Parameters |
177 | 178 | ---------- |
178 | | - delay_target: Array, ndarray, Variable |
| 179 | + delay_target: ArrayType |
179 | 180 | The delay target. |
180 | 181 | delay_len: float, int |
181 | 182 | The maximum delay length. The unit is the time. |
182 | 183 | t0: int, float |
183 | 184 | The zero time. |
184 | | - before_t0: int, float, ndarray, Array |
| 185 | + before_t0: int, float, ArrayType |
185 | 186 | The data before t0. |
186 | 187 | """ |
187 | 188 | self.delay_len = delay_len |
@@ -212,8 +213,12 @@ def __call__(self, time, indices=None): |
212 | 213 | # check |
213 | 214 | if check.is_checking(): |
214 | 215 | current_time = self.current_time[0] |
215 | | - check_error_in_jit(time > current_time + 1e-6, self._check_time1, (time, current_time)) |
216 | | - check_error_in_jit(time < current_time - self.delay_len - self.dt, self._check_time2, (time, current_time)) |
| 216 | + check_error_in_jit(time > current_time + 1e-6, |
| 217 | + self._check_time1, |
| 218 | + (time, current_time)) |
| 219 | + check_error_in_jit(time < current_time - self.delay_len - self.dt, |
| 220 | + self._check_time2, |
| 221 | + (time, current_time)) |
217 | 222 | if self._before_type == _FUNC_BEFORE: |
218 | 223 | res = cond(time < self.t0, |
219 | 224 | self._before_t0, |
@@ -251,9 +256,9 @@ def _false_fn(self, div_mod): |
251 | 256 | idx %= self.num_delay_step |
252 | 257 | return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx]) |
253 | 258 |
|
254 | | - def update(self, time, value): |
| 259 | + def update(self, value): |
255 | 260 | self.data[self.idx[0]] = value |
256 | | - self.current_time[0] = time |
| 261 | + self.current_time += self.dt |
257 | 262 | self.idx.value = (self.idx + 1) % self.num_delay_step |
258 | 263 |
|
259 | 264 |
|
@@ -411,7 +416,7 @@ def retrieve(self, delay_len, *indices): |
411 | 416 |
|
412 | 417 | Parameters |
413 | 418 | ---------- |
414 | | - delay_len: int, Array |
| 419 | + delay_len: int, ArrayType |
415 | 420 | The delay length used to retrieve the data. |
416 | 421 | """ |
417 | 422 | if check.is_checking(): |
|
0 commit comments