@@ -280,12 +280,24 @@ class LengthDelay(AbstractDelay):
280280 It can also be arrays. Or a callable function or instance of ``Connector``.
281281 Note that ``initial_delay_data`` should be arranged as the following way::
282282
283- delay = delay_len [ data
284- delay = delay_len-1 data
283+ delay = 1 [ data
284+ delay = 2 data
285285 ... ....
286286 ... ....
287- delay = 2 data
288- delay = 1 data ]
287+ delay = delay_len-1 data
288+ delay = delay_len data ]
289+
290+ .. versionchanged:: 2.2.3.2
291+
292+ The data in the previous version of ``LengthDelay`` is::
293+
294+ delay = delay_len [ data
295+ delay = delay_len-1 data
296+ ... ....
297+ ... ....
298+ delay = 2 data
299+ delay = 1 data ]
300+
289301
290302 name: str
291303 The delay object name.
@@ -368,13 +380,13 @@ def reset(
368380 dtype = delay_target .dtype )
369381
370382 # update delay data
371- self .data [- 1 ] = delay_target
383+ self .data [0 ] = delay_target
372384 if initial_delay_data is None :
373385 pass
374386 elif isinstance (initial_delay_data , (ndarray , jnp .ndarray , float , int , bool )):
375- self .data [: - 1 ] = initial_delay_data
387+ self .data [1 : ] = initial_delay_data
376388 elif callable (initial_delay_data ):
377- self .data [: - 1 ] = initial_delay_data ((delay_len ,) + delay_target .shape ,
389+ self .data [1 : ] = initial_delay_data ((delay_len ,) + delay_target .shape ,
378390 dtype = delay_target .dtype )
379391 else :
380392 raise ValueError (f'"delay_data" does not support { type (initial_delay_data )} ' )
@@ -406,20 +418,22 @@ def retrieve(self, delay_len, *indices):
406418 check_error_in_jit (bm .any (delay_len >= self .num_delay_step ), self ._check_delay , delay_len )
407419
408420 if self .update_method == ROTATION_UPDATING :
409- # the delay length
410- delay_idx = (self .idx [0 ] - delay_len - 1 ) % self .num_delay_step
421+ delay_idx = (self .idx [0 ] + delay_len ) % self .num_delay_step
411422 delay_idx = stop_gradient (delay_idx )
412- if not jnp .issubdtype (delay_idx .dtype , jnp .integer ):
413- raise ValueError (f'"delay_len" must be integer, but we got { delay_len } ' )
414423
415424 elif self .update_method == CONCAT_UPDATING :
416- delay_idx = self . num_delay_step - 1 - delay_len
425+ delay_idx = delay_len
417426
418427 else :
419428 raise ValueError (f'Unknown updating method "{ self .update_method } "' )
420429
421- # the delay data
430+ # the delay index
431+ if isinstance (delay_idx , int ):
432+ pass
433+ elif hasattr (delay_idx , 'dtype' ) and not jnp .issubdtype (delay_idx .dtype , jnp .integer ):
434+ raise ValueError (f'"delay_len" must be integer, but we got { delay_idx } ' )
422435 indices = (delay_idx ,) + tuple (indices )
436+ # the delay data
423437 return self .data [indices ]
424438
425439 def update (self , value : Union [float , int , bool , JaxArray , jnp .DeviceArray ]):
@@ -435,7 +449,10 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
435449 self .idx .value = stop_gradient ((self .idx + 1 ) % self .num_delay_step )
436450
437451 elif self .update_method == CONCAT_UPDATING :
438- self .data .value = bm .vstack ([self .data [1 :], bm .broadcast_to (value ,self .data .shape [1 :])])
452+ if self .num_delay_step >= 2 :
453+ self .data .value = bm .vstack ([bm .broadcast_to (value , self .data .shape [1 :]), self .data [1 :]])
454+ else :
455+ self .data [:] = value
439456
440457 else :
441458 raise ValueError (f'Unknown updating method "{ self .update_method } "' )
0 commit comments