@@ -47,9 +47,16 @@ class DynamicalSystem(Base):
4747 The name of the dynamic system.
4848 """
4949
50+ """Global delay variables. Useful when the same target
51+ variable is used in multiple mappings."""
52+ global_delay_vars : Dict [str , bm .LengthDelay ] = dict ()
53+
5054 def __init__ (self , name = None ):
5155 super (DynamicalSystem , self ).__init__ (name = name )
5256
57+ # local delay variables
58+ self .local_delay_vars : Dict [str , bm .LengthDelay ] = dict ()
59+
5360 @property
5461 def steps (self ):
5562 warnings .warn ('.steps has been deprecated since version 2.0.3.' , DeprecationWarning )
@@ -81,6 +88,149 @@ def __call__(self, *args, **kwargs):
8188 """The shortcut to call ``update`` methods."""
8289 return self .update (* args , ** kwargs )
8390
91+ def register_delay (
92+ self ,
93+ name : str ,
94+ delay_step : Union [int , Tensor , Callable , Initializer ],
95+ delay_target : Union [bm .JaxArray , jnp .ndarray ],
96+ initial_delay_data : Union [Initializer , Callable , Tensor , float , int , bool ] = None ,
97+ domain : str = 'global'
98+ ):
99+ """Register delay variable.
100+
101+ Parameters
102+ ----------
103+ name: str
104+ The delay variable name.
105+ delay_step: int, JaxArray, ndarray, callable, Initializer
106+ The number of the steps of the delay.
107+ delay_target: JaxArray, ndarray, Variable
108+ The target for delay.
109+ initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
110+ The initializer for the delay data.
111+ domain: str
112+ The domain of the delay data to store.
113+
114+ Returns
115+ -------
116+ delay_step: int, JaxArray, ndarray
117+ The number of the delay steps.
118+ """
119+ # delay steps
120+ if delay_step is None :
121+ return delay_step
122+ elif isinstance (delay_step , int ):
123+ delay_type = 'homo'
124+ elif isinstance (delay_step , (bm .ndarray , jnp .ndarray , np .ndarray )):
125+ delay_type = 'heter'
126+ delay_step = bm .asarray (delay_step )
127+ elif callable (delay_step ):
128+ delay_step = init_param (delay_step , delay_target .shape , allow_none = False )
129+ delay_type = 'heter'
130+ else :
131+ raise ValueError (f'Unknown "delay_steps" type { type (delay_step )} , only support '
132+ f'integer, array of integers, callable function, brainpy.init.Initializer.' )
133+ if delay_type == 'heter' :
134+ if delay_step .dtype not in [bm .int32 , bm .int64 ]:
135+ raise ValueError ('Only support delay steps of int32, int64. If your '
136+ 'provide delay time length, please divide the "dt" '
137+ 'then provide us the number of delay steps.' )
138+ if delay_target .shape [0 ] != delay_step .shape [0 ]:
139+ raise ValueError (f'Shape is mismatched: { delay_target .shape [0 ]} != { delay_step .shape [0 ]} ' )
140+ max_delay_step = int (bm .max (delay_step ))
141+
142+ # delay domain
143+ if domain not in ['global' , 'local' ]:
144+ raise ValueError ('"domain" must be a string in ["global", "local"]. '
145+ f'Bug we got { domain } .' )
146+
147+ # delay variable
148+ if domain == 'local' :
149+ self .local_delay_vars [name ] = bm .LengthDelay (delay_target , max_delay_step , initial_delay_data )
150+ self .register_implicit_nodes (self .local_delay_vars )
151+ else :
152+ if name not in self .global_delay_vars :
153+ self .global_delay_vars [name ] = bm .LengthDelay (delay_target , max_delay_step , initial_delay_data )
154+ # save into local delay vars when first seen "var",
155+ # for later update current value!
156+ self .local_delay_vars [name ] = self .global_delay_vars [name ]
157+ else :
158+ if self .global_delay_vars [name ].num_delay_step - 1 < max_delay_step :
159+ self .global_delay_vars [name ].reset (delay_target , max_delay_step , initial_delay_data )
160+ self .register_implicit_nodes (self .global_delay_vars )
161+ return delay_step
162+
163+ def get_delay_data (
164+ self ,
165+ name : str ,
166+ delay_step : Union [int , bm .JaxArray , jnp .DeviceArray ],
167+ indices : Union [int , bm .JaxArray , jnp .DeviceArray ] = None ,
168+ ):
169+ """Get delay data according to the provided delay steps.
170+
171+ Parameters
172+ ----------
173+ name: str
174+ The delay variable name.
175+ delay_step: int, JaxArray, ndarray
176+ The delay length.
177+ indices: optional, int, JaxArray, ndarray
178+ The indices of the delay.
179+
180+ Returns
181+ -------
182+ delay_data: JaxArray, ndarray
183+ The delay data at the given time.
184+ """
185+ if name in self .global_delay_vars :
186+ if isinstance (delay_step , int ):
187+ return self .global_delay_vars [name ](delay_step , indices )
188+ else :
189+ if indices is None :
190+ indices = jnp .arange (delay_step .size )
191+ return self .global_delay_vars [name ](delay_step , indices )
192+ elif name in self .local_delay_vars :
193+ if isinstance (delay_step , int ):
194+ return self .local_delay_vars [name ](delay_step )
195+ else :
196+ if indices is None :
197+ indices = jnp .arange (delay_step .size )
198+ return self .local_delay_vars [name ](delay_step , indices )
199+ else :
200+ raise ValueError (f'{ name } is not defined in delay variables.' )
201+
202+ def update_delay (
203+ self ,
204+ name : str ,
205+ delay_data : Union [float , bm .JaxArray , jnp .ndarray ]
206+ ):
207+ """Update the delay according to the delay data.
208+
209+ Parameters
210+ ----------
211+ name: str
212+ The name of the delay.
213+ delay_data: float, JaxArray, ndarray
214+ The delay data to update at the current time.
215+ """
216+ if name in self .local_delay_vars :
217+ return self .local_delay_vars [name ].update (delay_data )
218+ else :
219+ if name not in self .global_delay_vars :
220+ raise ValueError (f'{ name } is not defined in delay variables.' )
221+
222+ def reset_delay (
223+ self ,
224+ name : str ,
225+ delay_target : Union [bm .JaxArray , jnp .DeviceArray ]
226+ ):
227+ """Reset the delay variable."""
228+ if name in self .local_delay_vars :
229+ return self .local_delay_vars [name ].reset (delay_target )
230+ else :
231+ if name not in self .global_delay_vars :
232+ raise ValueError (f'{ name } is not defined in delay variables.' )
233+
84234 def update (self , _t , _dt ):
85235 """The function to specify the updating rule.
86236 Assume any dynamical system depends on the time variable ``t`` and
@@ -356,19 +506,13 @@ class TwoEndConn(DynamicalSystem):
356506 The name of the dynamic system.
357507 """
358508
359- """Global delay variables. Useful when the same target
360- variable is used in multiple mappings."""
361- global_delay_vars : Dict [str , bm .LengthDelay ] = dict ()
362-
363509 def __init__ (
364510 self ,
365511 pre : NeuGroup ,
366512 post : NeuGroup ,
367513 conn : Union [TwoEndConnector , Tensor , Dict [str , Tensor ]] = None ,
368514 name : str = None
369515 ):
370- # local delay variables
371- self .local_delay_vars : Dict [str , bm .LengthDelay ] = dict ()
372516
373517 # pre or post neuron group
374518 # ------------------------
@@ -425,146 +569,3 @@ def check_post_attrs(self, *attrs):
425569 raise ValueError (f'Must be string. But got { attr } .' )
426570 if not hasattr (self .post , attr ):
427571 raise ModelBuildError (f'{ self } need "pre" neuron group has attribute "{ attr } ".' )
428-
429- def register_delay (
430- self ,
431- name : str ,
432- delay_step : Union [int , bm .ndarray , jnp .ndarray , Callable , Initializer ],
433- delay_target : Union [bm .JaxArray , jnp .ndarray ],
434- initial_delay_data : Union [Initializer , Callable ] = None ,
435- domain : str = 'global'
436- ):
437- """Register delay variable.
438-
439- Parameters
440- ----------
441- name: str
442- The delay variable name.
443- delay_step: int, JaxArray, ndarray, callable, Initializer
444- The number of the steps of the delay.
445- delay_target: JaxArray, ndarray, Variable
446- The target for delay.
447- initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
448- The initializer for the delay data.
449- domain: str
450- The domain of the delay data to store.
451-
452- Returns
453- -------
454- delay_step: int, JaxArray, ndarray
455- The number of the delay steps.
456- """
457- # delay steps
458- if delay_step is None :
459- return delay_step
460- elif isinstance (delay_step , int ):
461- delay_type = 'homo'
462- elif isinstance (delay_step , (bm .ndarray , jnp .ndarray , np .ndarray )):
463- delay_type = 'heter'
464- delay_step = bm .asarray (delay_step )
465- elif callable (delay_step ):
466- delay_step = init_param (delay_step , delay_target .shape , allow_none = False )
467- delay_type = 'heter'
468- else :
469- raise ValueError (f'Unknown "delay_steps" type { type (delay_step )} , only support '
470- f'integer, array of integers, callable function, brainpy.init.Initializer.' )
471- if delay_type == 'heter' :
472- if delay_step .dtype not in [bm .int32 , bm .int64 ]:
473- raise ValueError ('Only support delay steps of int32, int64. If your '
474- 'provide delay time length, please divide the "dt" '
475- 'then provide us the number of delay steps.' )
476- if delay_target .shape [0 ] != delay_step .shape [0 ]:
477- raise ValueError (f'Shape is mismatched: { delay_target .shape [0 ]} != { delay_step .shape [0 ]} ' )
478- max_delay_step = int (bm .max (delay_step ))
479-
480- # delay domain
481- if domain not in ['global' , 'local' ]:
482- raise ValueError ('"domain" must be a string in ["global", "local"]. '
483- f'Bug we got { domain } .' )
484-
485- # delay variable
486- if domain == 'local' :
487- self .local_delay_vars [name ] = bm .LengthDelay (delay_target , max_delay_step , initial_delay_data )
488- self .register_implicit_nodes (self .local_delay_vars )
489- else :
490- if name not in self .global_delay_vars :
491- self .global_delay_vars [name ] = bm .LengthDelay (delay_target , max_delay_step , initial_delay_data )
492- # save into local delay vars when first seen "var",
493- # for later update current value!
494- self .local_delay_vars [name ] = self .global_delay_vars [name ]
495- else :
496- if self .global_delay_vars [name ].num_delay_step - 1 < max_delay_step :
497- self .global_delay_vars [name ].reset (delay_target , max_delay_step , initial_delay_data )
498- self .register_implicit_nodes (self .global_delay_vars )
499- return delay_step
500-
501- def get_delay_data (
502- self ,
503- name : str ,
504- delay_step : Union [int , bm .JaxArray , jnp .DeviceArray ],
505- indices : Union [int , bm .JaxArray , jnp .DeviceArray ] = None ,
506- ):
507- """Get delay data according to the provided delay steps.
508-
509- Parameters
510- ----------
511- name: str
512- The delay variable name.
513- delay_step: int, JaxArray, ndarray
514- The delay length.
515- indices: optional, int, JaxArray, ndarray
516- The indices of the delay.
517-
518- Returns
519- -------
520- delay_data: JaxArray, ndarray
521- The delay data at the given time.
522- """
523- if name in self .global_delay_vars :
524- if isinstance (delay_step , int ):
525- return self .global_delay_vars [name ](delay_step , indices )
526- else :
527- if indices is None :
528- indices = jnp .arange (delay_step .size )
529- return self .global_delay_vars [name ](delay_step , indices )
530- elif name in self .local_delay_vars :
531- if isinstance (delay_step , int ):
532- return self .local_delay_vars [name ](delay_step )
533- else :
534- if indices is None :
535- indices = jnp .arange (delay_step .size )
536- return self .local_delay_vars [name ](delay_step , indices )
537- else :
538- raise ValueError (f'{ name } is not defined in delay variables.' )
539-
540- def update_delay (
541- self ,
542- name : str ,
543- delay_data : Union [float , bm .JaxArray , jnp .ndarray ]
544- ):
545- """Update the delay according to the delay data.
546-
547- Parameters
548- ----------
549- name: str
550- The name of the delay.
551- delay_data: float, JaxArray, ndarray
552- The delay data to update at the current time.
553- """
554- if name in self .local_delay_vars :
555- return self .local_delay_vars [name ].update (delay_data )
556- else :
557- if name not in self .global_delay_vars :
558- raise ValueError (f'{ name } is not defined in delay variables.' )
559-
560- def reset_delay (
561- self ,
562- name : str ,
563- delay_target : Union [bm .JaxArray , jnp .DeviceArray ]
564- ):
565- """Reset the delay variable."""
566- if name in self .local_delay_vars :
567- return self .local_delay_vars [name ].reset (delay_target )
568- else :
569- if name not in self .global_delay_vars :
570- raise ValueError (f'{ name } is not defined in delay variables.' )
0 commit comments