11# -*- coding: utf-8 -*-
22
3-
3+ import warnings
44from typing import Union , Callable , Tuple
55
66import jax .numpy as jnp
7+ import numpy as np
78from jax import vmap
89from jax .experimental .host_callback import id_tap
910from jax .lax import cond
1011
12+ from brainpy import check
1113from brainpy import math as bm
1214from brainpy .base .base import Base
15+ from brainpy .errors import UnsupportedError
1316from brainpy .tools .checking import check_float
1417from brainpy .tools .others import to_size
15- from brainpy .errors import UnsupportedError
1618
1719__all__ = [
1820 'AbstractDelay' ,
21+ 'TimeDelay' ,
1922 'FixedLenDelay' ,
2023 'NeutralDelay' ,
2124]
@@ -32,35 +35,35 @@ def update(self, time, value):
3235_INTERP_ROUND = 'round'
3336
3437
35- class FixedLenDelay (AbstractDelay ):
36- """Delay variable which has a fixed delay length.
38+ class TimeDelay (AbstractDelay ):
39+ """Delay variable which has a fixed delay time length.
3740
3841 For example, we create a delay variable which has a maximum delay length of 1 ms
3942
4043 >>> import brainpy.math as bm
41- >>> delay = bm.FixedLenDelay (bm.zeros(3), delay_len=1., dt=0.1)
44+ >>> delay = bm.TimeDelay (bm.zeros(3), delay_len=1., dt=0.1)
4245 >>> delay(-0.5)
4346 [-0. -0. -0.]
4447
4548 This function supports multiple dimensions of the tensor. For example,
4649
4750 1. the one-dimensional delay data
4851
49- >>> delay = bm.FixedLenDelay (3, delay_len=1., dt=0.1, before_t0=lambda t: t)
52+ >>> delay = bm.TimeDelay (3, delay_len=1., dt=0.1, before_t0=lambda t: t)
5053 >>> delay(-0.2)
5154 [-0.2 -0.2 -0.2]
5255
5356 2. the two-dimensional delay data
5457
55- >>> delay = bm.FixedLenDelay ((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
58+ >>> delay = bm.TimeDelay ((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t)
5659 >>> delay(-0.6)
5760 [[-0.6 -0.6]
5861 [-0.6 -0.6]
5962 [-0.6 -0.6]]
6063
6164 3. the three-dimensional delay data
6265
63- >>> delay = bm.FixedLenDelay ((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
66+ >>> delay = bm.TimeDelay ((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t)
6467 >>> delay(-0.6)
6568 [[[-0.8]
6669 [-0.8]]
@@ -113,7 +116,7 @@ def __init__(
113116 dtype = None ,
114117 interp_method = 'linear_interp' ,
115118 ):
116- super (FixedLenDelay , self ).__init__ (name = name )
119+ super (TimeDelay , self ).__init__ (name = name )
117120
118121 # shape
119122 self .shape = to_size (shape )
@@ -161,6 +164,10 @@ def __init__(
161164 else :
162165 raise ValueError (f'"before_t0" does not support { type (before_t0 )} : before_t0' )
163166
167+ self .f = jnp .interp
168+ for dim in range (1 , len (self .shape ) + 1 , 1 ):
169+ self .f = vmap (self .f , in_axes = (None , None , dim ), out_axes = dim - 1 )
170+
164171 @property
165172 def idx (self ):
166173 return self ._idx
@@ -191,36 +198,37 @@ def current_time(self):
191198
192199 def _check_time (self , times , transforms ):
193200 prev_time , current_time = times
194- current_time = bm . as_device_array (current_time )
195- prev_time = bm . as_device_array (prev_time )
201+ current_time = np . asarray (current_time , dtype = bm . float_ )
202+ prev_time = np . asarray (prev_time , dtype = bm . float_ )
196203 if prev_time > current_time :
197204 raise ValueError (f'\n '
198205 f'!!! Error in { self .__class__ .__name__ } : \n '
199206 f'The request time should be less than the '
200207 f'current time { current_time } . But we '
201208 f'got { prev_time } > { current_time } ' )
202- lower_time = jnp .asarray (current_time - self .delay_len )
209+ lower_time = np .asarray (current_time - self .delay_len )
203210 if prev_time < lower_time :
204211 raise ValueError (f'\n '
205212 f'!!! Error in { self .__class__ .__name__ } : \n '
206213 f'The request time of the variable should be in '
207214 f'[{ lower_time } , { current_time } ], but we got { prev_time } ' )
208215
209- def __call__ (self , prev_time ):
216+ def __call__ (self , time , indices = None ):
210217 # check
211- id_tap (self ._check_time , (prev_time , self .current_time ))
218+ if check .is_checking ():
219+ id_tap (self ._check_time , (time , self .current_time ))
212220 if self ._before_type == _FUNC_BEFORE :
213- return cond (prev_time < self .t0 ,
221+ return cond (time < self .t0 ,
214222 self ._before_t0 ,
215223 self ._after_t0 ,
216- prev_time )
224+ time )
217225 else :
218- return self ._after_t0 (prev_time )
226+ return self ._after_t0 (time )
219227
220228 def _after_t0 (self , prev_time ):
221229 diff = self .delay_len - (self .current_time - prev_time )
222- if isinstance (diff , bm .ndarray ): diff = diff . value
223-
230+ if isinstance (diff , bm .ndarray ):
231+ diff = diff . value
224232 if self .interp_method == _INTERP_LINEAR :
225233 req_num_step = jnp .asarray (diff / self ._dt , dtype = bm .get_dint ())
226234 extra = diff - req_num_step * self ._dt
@@ -238,31 +246,43 @@ def _true_fn(self, div_mod):
238246
239247 def _false_fn (self , div_mod ):
240248 req_num_step , extra = div_mod
241- f = jnp .interp
242- for dim in range (1 , len (self .shape ) + 1 , 1 ):
243- f = vmap (f , in_axes = (None , None , dim ), out_axes = dim - 1 )
244249 idx = jnp .asarray ([self .idx [0 ] + req_num_step ,
245250 self .idx [0 ] + req_num_step + 1 ])
246251 idx %= self .num_delay_step
247- return f (extra , jnp .asarray ([0. , self ._dt ]), self ._data [idx ])
252+ return self . f (extra , jnp .asarray ([0. , self ._dt ]), self ._data [idx ])
248253
249254 def update (self , time , value ):
250255 self ._data [self ._idx [0 ]] = value
251256 self ._current_time [0 ] = time
252257 self ._idx .value = (self ._idx + 1 ) % self .num_delay_step
253258
254259
255- class VariedLenDelay (AbstractDelay ):
256- """Delay variable which has a functional delay
257-
258- """
260+ def FixedLenDelay (shape : Union [int , Tuple [int , ...]],
261+ delay_len : Union [float , int ],
262+ before_t0 : Union [Callable , bm .ndarray , jnp .ndarray , float , int ] = None ,
263+ t0 : Union [float , int ] = 0. ,
264+ dt : Union [float , int ] = None ,
265+ name : str = None ,
266+ dtype = None ,
267+ interp_method = 'linear_interp' , ):
268+ warnings .warn ('Please use "brainpy.math.TimeDelay" instead. '
269+ '"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ' ,
270+ DeprecationWarning )
271+ return TimeDelay (shape = shape ,
272+ delay_len = delay_len ,
273+ before_t0 = before_t0 ,
274+ t0 = t0 ,
275+ dt = dt ,
276+ name = name ,
277+ dtype = dtype ,
278+ interp_method = interp_method )
279+
280+
281+ class NeutralDelay (TimeDelay ):
282+ pass
259283
260- def update (self , time , value ):
261- pass
262284
263- def __init__ ( self ):
264- super ( VariedLenDelay , self ). __init__ ()
285+ class LengthDelay ( AbstractDelay ):
286+ pass
265287
266288
267- class NeutralDelay (FixedLenDelay ):
268- pass
0 commit comments