1414
1515__all__ = [
1616 'LoopOverTime' ,
17- 'NoSharedArg' ,
1817]
1918
2019
@@ -207,12 +206,13 @@ def __call__(
207206 if isinstance (duration_or_xs , float ):
208207 shared = tools .DotDict ()
209208 if self .t0 is not None :
210- shared ['t' ] = jnp .arange (self . t0 . value , duration_or_xs , self .dt )
209+ shared ['t' ] = jnp .arange (0 , duration_or_xs , self .dt ) + self . t0 . value
211210 if self .i0 is not None :
212- shared ['i' ] = jnp .arange (self . i0 . value , shared ['t' ].shape [0 ])
211+ shared ['i' ] = jnp .arange (0 , shared ['t' ].shape [0 ]) + self . i0 . value
213212 xs = None
214213 if self .no_state :
215214 raise ValueError ('Under the `no_state=True` setting, input cannot be a duration.' )
215+ length = shared ['t' ].shape
216216
217217 else :
218218 inp_err_msg = ('\n '
@@ -278,8 +278,8 @@ def __call__(
278278
279279 else :
280280 shared = tools .DotDict ()
281- shared ['t' ] = jnp .arange (self . t0 . value , self .dt * length [0 ], self .dt )
282- shared ['i' ] = jnp .arange (self . i0 . value , length [0 ])
281+ shared ['t' ] = jnp .arange (0 , self .dt * length [0 ], self .dt ) + self . t0 . value
282+ shared ['i' ] = jnp .arange (0 , length [0 ]) + self . i0 . value
283283
284284 assert not self .no_state
285285 results = bm .for_loop (functools .partial (self ._run , self .shared_arg ),
@@ -295,6 +295,10 @@ def __call__(
295295
296296 def reset_state (self , batch_size = None ):
297297 self .target .reset_state (batch_size )
298+ if self .i0 is not None :
299+ self .i0 .value = jnp .asarray (0 )
300+ if self .t0 is not None :
301+ self .t0 .value = jnp .asarray (0. )
298302
299303 def _run (self , static_sh , dyn_sh , x ):
300304 share .save (** static_sh , ** dyn_sh )
@@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
304308 self .target .clear_input ()
305309 return outs
306310
307-
308- class NoSharedArg (DynSysToBPObj ):
309- """Transform an instance of :py:class:`~.DynamicalSystem` into a callable
310- :py:class:`~.BrainPyObject` :math:`y=f(x)`.
311-
312- .. note::
313-
314- This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
315-
316- If some children nodes need shared arguments, like :py:class:`~.Dropout` or
317- :py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.
318-
319- Examples
320- --------
321-
322- >>> import brainpy as bp
323- >>> import brainpy.math as bm
324- >>> l = bp.Sequential(bp.layers.Dense(100, 10),
325- >>> bm.relu,
326- >>> bp.layers.Dense(10, 2))
327- >>> l = bp.NoSharedArg(l)
328- >>> l(bm.random.random(256, 100))
329-
330- Parameters
331- ----------
332- target: DynamicalSystem
333- The target to transform.
334- name: str
335- The transformed object name.
336- """
337-
338- def __init__ (self , target : DynamicalSystem , name : str = None ):
339- super ().__init__ (target = target , name = name )
340- if isinstance (target , Sequential ) and target .no_shared_arg :
341- raise ValueError (f'It is a { Sequential .__name__ } object with `no_shared_arg=True`, '
342- f'which has already able to be called with `f(x)`. ' )
343-
344- def __call__ (self , * args , ** kwargs ):
345- return self .target (tools .DotDict (), * args , ** kwargs )
346-
347- def reset (self , batch_size = None ):
348- """Reset function which reset the whole variables in the model.
349- """
350- self .target .reset (batch_size )
351-
352- def reset_state (self , batch_size = None ):
353- self .target .reset_state (batch_size )
0 commit comments