Skip to content

Commit a04579c

Browse files
authored
feat: update delay couplings of DiffusiveCoupling and AdditiveCouping (#190)
feat: update delay couplings of `DiffusiveCoupling` and `AdditiveCouping`
2 parents fee5d2d + 2d72883 commit a04579c

File tree

9 files changed

+461
-342
lines changed

9 files changed

+461
-342
lines changed

brainpy/analysis/lowdim/lowdim_analyzer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,14 @@ def __init__(self,
148148
_target_vp = self.target_vars + self.target_pars
149149
if resolutions is None:
150150
for key, lim in self.target_vars.items():
151-
self.resolutions[key] = bm.asarray(np.linspace(*lim, 20))
151+
self.resolutions[key] = bm.linspace(*lim, 20)
152152
for key, lim in self.target_pars.items():
153-
self.resolutions[key] = bm.asarray(np.linspace(*lim, 20))
153+
self.resolutions[key] = bm.linspace(*lim, 20)
154154
elif isinstance(resolutions, float):
155155
for key, lim in self.target_vars.items():
156-
self.resolutions[key] = bm.asarray(np.arange(*lim, resolutions))
156+
self.resolutions[key] = bm.arange(*lim, resolutions)
157157
for key, lim in self.target_pars.items():
158-
self.resolutions[key] = bm.asarray(np.arange(*lim, resolutions))
158+
self.resolutions[key] = bm.arange(*lim, resolutions)
159159
elif isinstance(resolutions, dict):
160160
for key in resolutions.keys():
161161
if key in self.target_var_names:
@@ -167,11 +167,11 @@ def __init__(self,
167167
f'the target parameters {self.target_par_names}.')
168168
for key in self.target_var_names + self.target_par_names:
169169
if key not in resolutions:
170-
self.resolutions[key] = bm.asarray(np.linspace(*_target_vp[key], 20))
170+
self.resolutions[key] = bm.linspace(*_target_vp[key], 20)
171171
else:
172172
resolution = resolutions[key]
173173
if isinstance(resolution, float):
174-
self.resolutions[key] = bm.asarray(np.arange(*_target_vp[key], resolution))
174+
self.resolutions[key] = bm.arange(*_target_vp[key], resolution)
175175
elif isinstance(resolution, (bm.ndarray, np.ndarray, jnp.ndarray)):
176176
if not np.ndim(resolution) == 1:
177177
raise errors.AnalyzerError(f'resolution must be a 1D array, but get its '

brainpy/dyn/base.py

Lines changed: 150 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,16 @@ class DynamicalSystem(Base):
4747
The name of the dynamic system.
4848
"""
4949

50+
"""Global delay variables. Useful when the same target
51+
variable is used in multiple mappings."""
52+
global_delay_vars: Dict[str, bm.LengthDelay] = dict()
53+
5054
def __init__(self, name=None):
5155
super(DynamicalSystem, self).__init__(name=name)
5256

57+
# local delay variables
58+
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()
59+
5360
@property
5461
def steps(self):
5562
warnings.warn('.steps has been deprecated since version 2.0.3.', DeprecationWarning)
@@ -81,6 +88,149 @@ def __call__(self, *args, **kwargs):
8188
"""The shortcut to call ``update`` methods."""
8289
return self.update(*args, **kwargs)
8390

91+
def register_delay(
92+
self,
93+
name: str,
94+
delay_step: Union[int, Tensor, Callable, Initializer],
95+
delay_target: Union[bm.JaxArray, jnp.ndarray],
96+
initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,
97+
domain: str = 'global'
98+
):
99+
"""Register delay variable.
100+
101+
Parameters
102+
----------
103+
name: str
104+
The delay variable name.
105+
delay_step: int, JaxArray, ndarray, callable, Initializer
106+
The number of the steps of the delay.
107+
delay_target: JaxArray, ndarray, Variable
108+
The target for delay.
109+
initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
110+
The initializer for the delay data.
111+
domain: str
112+
The domain of the delay data to store.
113+
114+
Returns
115+
-------
116+
delay_step: int, JaxArray, ndarray
117+
The number of the delay steps.
118+
"""
119+
# delay steps
120+
if delay_step is None:
121+
return delay_step
122+
elif isinstance(delay_step, int):
123+
delay_type = 'homo'
124+
elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
125+
delay_type = 'heter'
126+
delay_step = bm.asarray(delay_step)
127+
elif callable(delay_step):
128+
delay_step = init_param(delay_step, delay_target.shape, allow_none=False)
129+
delay_type = 'heter'
130+
else:
131+
raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
132+
f'integer, array of integers, callable function, brainpy.init.Initializer.')
133+
if delay_type == 'heter':
134+
if delay_step.dtype not in [bm.int32, bm.int64]:
135+
raise ValueError('Only support delay steps of int32, int64. If your '
136+
'provide delay time length, please divide the "dt" '
137+
'then provide us the number of delay steps.')
138+
if delay_target.shape[0] != delay_step.shape[0]:
139+
raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
140+
max_delay_step = int(bm.max(delay_step))
141+
142+
# delay domain
143+
if domain not in ['global', 'local']:
144+
raise ValueError('"domain" must be a string in ["global", "local"]. '
145+
f'Bug we got {domain}.')
146+
147+
# delay variable
148+
if domain == 'local':
149+
self.local_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
150+
self.register_implicit_nodes(self.local_delay_vars)
151+
else:
152+
if name not in self.global_delay_vars:
153+
self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
154+
# save into local delay vars when first seen "var",
155+
# for later update current value!
156+
self.local_delay_vars[name] = self.global_delay_vars[name]
157+
else:
158+
if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
159+
self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data)
160+
self.register_implicit_nodes(self.global_delay_vars)
161+
return delay_step
162+
163+
def get_delay_data(
164+
self,
165+
name: str,
166+
delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
167+
indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None,
168+
):
169+
"""Get delay data according to the provided delay steps.
170+
171+
Parameters
172+
----------
173+
name: str
174+
The delay variable name.
175+
delay_step: int, JaxArray, ndarray
176+
The delay length.
177+
indices: optional, int, JaxArray, ndarray
178+
The indices of the delay.
179+
180+
Returns
181+
-------
182+
delay_data: JaxArray, ndarray
183+
The delay data at the given time.
184+
"""
185+
if name in self.global_delay_vars:
186+
if isinstance(delay_step, int):
187+
return self.global_delay_vars[name](delay_step, indices)
188+
else:
189+
if indices is None:
190+
indices = jnp.arange(delay_step.size)
191+
return self.global_delay_vars[name](delay_step, indices)
192+
elif name in self.local_delay_vars:
193+
if isinstance(delay_step, int):
194+
return self.local_delay_vars[name](delay_step)
195+
else:
196+
if indices is None:
197+
indices = jnp.arange(delay_step.size)
198+
return self.local_delay_vars[name](delay_step, indices)
199+
else:
200+
raise ValueError(f'{name} is not defined in delay variables.')
201+
202+
def update_delay(
203+
self,
204+
name: str,
205+
delay_data: Union[float, bm.JaxArray, jnp.ndarray]
206+
):
207+
"""Update the delay according to the delay data.
208+
209+
Parameters
210+
----------
211+
name: str
212+
The name of the delay.
213+
delay_data: float, JaxArray, ndarray
214+
The delay data to update at the current time.
215+
"""
216+
if name in self.local_delay_vars:
217+
return self.local_delay_vars[name].update(delay_data)
218+
else:
219+
if name not in self.global_delay_vars:
220+
raise ValueError(f'{name} is not defined in delay variables.')
221+
222+
def reset_delay(
223+
self,
224+
name: str,
225+
delay_target: Union[bm.JaxArray, jnp.DeviceArray]
226+
):
227+
"""Reset the delay variable."""
228+
if name in self.local_delay_vars:
229+
return self.local_delay_vars[name].reset(delay_target)
230+
else:
231+
if name not in self.global_delay_vars:
232+
raise ValueError(f'{name} is not defined in delay variables.')
233+
84234
def update(self, _t, _dt):
85235
"""The function to specify the updating rule.
86236
Assume any dynamical system depends on the time variable ``t`` and
@@ -356,19 +506,13 @@ class TwoEndConn(DynamicalSystem):
356506
The name of the dynamic system.
357507
"""
358508

359-
"""Global delay variables. Useful when the same target
360-
variable is used in multiple mappings."""
361-
global_delay_vars: Dict[str, bm.LengthDelay] = dict()
362-
363509
def __init__(
364510
self,
365511
pre: NeuGroup,
366512
post: NeuGroup,
367513
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None,
368514
name: str = None
369515
):
370-
# local delay variables
371-
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()
372516

373517
# pre or post neuron group
374518
# ------------------------
@@ -425,146 +569,3 @@ def check_post_attrs(self, *attrs):
425569
raise ValueError(f'Must be string. But got {attr}.')
426570
if not hasattr(self.post, attr):
427571
raise ModelBuildError(f'{self} need "pre" neuron group has attribute "{attr}".')
428-
429-
def register_delay(
430-
self,
431-
name: str,
432-
delay_step: Union[int, bm.ndarray, jnp.ndarray, Callable, Initializer],
433-
delay_target: Union[bm.JaxArray, jnp.ndarray],
434-
initial_delay_data: Union[Initializer, Callable] = None,
435-
domain: str = 'global'
436-
):
437-
"""Register delay variable.
438-
439-
Parameters
440-
----------
441-
name: str
442-
The delay variable name.
443-
delay_step: int, JaxArray, ndarray, callable, Initializer
444-
The number of the steps of the delay.
445-
delay_target: JaxArray, ndarray, Variable
446-
The target for delay.
447-
initial_delay_data: float, int, JaxArray, ndarray, callable, Initializer
448-
The initializer for the delay data.
449-
domain: str
450-
The domain of the delay data to store.
451-
452-
Returns
453-
-------
454-
delay_step: int, JaxArray, ndarray
455-
The number of the delay steps.
456-
"""
457-
# delay steps
458-
if delay_step is None:
459-
return delay_step
460-
elif isinstance(delay_step, int):
461-
delay_type = 'homo'
462-
elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
463-
delay_type = 'heter'
464-
delay_step = bm.asarray(delay_step)
465-
elif callable(delay_step):
466-
delay_step = init_param(delay_step, delay_target.shape, allow_none=False)
467-
delay_type = 'heter'
468-
else:
469-
raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
470-
f'integer, array of integers, callable function, brainpy.init.Initializer.')
471-
if delay_type == 'heter':
472-
if delay_step.dtype not in [bm.int32, bm.int64]:
473-
raise ValueError('Only support delay steps of int32, int64. If your '
474-
'provide delay time length, please divide the "dt" '
475-
'then provide us the number of delay steps.')
476-
if delay_target.shape[0] != delay_step.shape[0]:
477-
raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
478-
max_delay_step = int(bm.max(delay_step))
479-
480-
# delay domain
481-
if domain not in ['global', 'local']:
482-
raise ValueError('"domain" must be a string in ["global", "local"]. '
483-
f'Bug we got {domain}.')
484-
485-
# delay variable
486-
if domain == 'local':
487-
self.local_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
488-
self.register_implicit_nodes(self.local_delay_vars)
489-
else:
490-
if name not in self.global_delay_vars:
491-
self.global_delay_vars[name] = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
492-
# save into local delay vars when first seen "var",
493-
# for later update current value!
494-
self.local_delay_vars[name] = self.global_delay_vars[name]
495-
else:
496-
if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
497-
self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data)
498-
self.register_implicit_nodes(self.global_delay_vars)
499-
return delay_step
500-
501-
def get_delay_data(
502-
self,
503-
name: str,
504-
delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
505-
indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None,
506-
):
507-
"""Get delay data according to the provided delay steps.
508-
509-
Parameters
510-
----------
511-
name: str
512-
The delay variable name.
513-
delay_step: int, JaxArray, ndarray
514-
The delay length.
515-
indices: optional, int, JaxArray, ndarray
516-
The indices of the delay.
517-
518-
Returns
519-
-------
520-
delay_data: JaxArray, ndarray
521-
The delay data at the given time.
522-
"""
523-
if name in self.global_delay_vars:
524-
if isinstance(delay_step, int):
525-
return self.global_delay_vars[name](delay_step, indices)
526-
else:
527-
if indices is None:
528-
indices = jnp.arange(delay_step.size)
529-
return self.global_delay_vars[name](delay_step, indices)
530-
elif name in self.local_delay_vars:
531-
if isinstance(delay_step, int):
532-
return self.local_delay_vars[name](delay_step)
533-
else:
534-
if indices is None:
535-
indices = jnp.arange(delay_step.size)
536-
return self.local_delay_vars[name](delay_step, indices)
537-
else:
538-
raise ValueError(f'{name} is not defined in delay variables.')
539-
540-
def update_delay(
541-
self,
542-
name: str,
543-
delay_data: Union[float, bm.JaxArray, jnp.ndarray]
544-
):
545-
"""Update the delay according to the delay data.
546-
547-
Parameters
548-
----------
549-
name: str
550-
The name of the delay.
551-
delay_data: float, JaxArray, ndarray
552-
The delay data to update at the current time.
553-
"""
554-
if name in self.local_delay_vars:
555-
return self.local_delay_vars[name].update(delay_data)
556-
else:
557-
if name not in self.global_delay_vars:
558-
raise ValueError(f'{name} is not defined in delay variables.')
559-
560-
def reset_delay(
561-
self,
562-
name: str,
563-
delay_target: Union[bm.JaxArray, jnp.DeviceArray]
564-
):
565-
"""Reset the delay variable."""
566-
if name in self.local_delay_vars:
567-
return self.local_delay_vars[name].reset(delay_target)
568-
else:
569-
if name not in self.global_delay_vars:
570-
raise ValueError(f'{name} is not defined in delay variables.')

brainpy/dyn/neurons/rate_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
# other parameters
9090
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
9191
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05),
92-
method: str = None,
92+
method: str = 'exp_auto',
9393
sde_method: str = None,
9494
name: str = None,
9595
):
@@ -556,7 +556,7 @@ def __init__(
556556
# other parameters
557557
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
558558
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5),
559-
method: str = None,
559+
method: str = 'exp_auto',
560560
sde_method: str = None,
561561
name: str = None,
562562
):

0 commit comments

Comments
 (0)