Skip to content

Commit 176565f

Browse files
authored
Merge pull request #339 from chaoming0625/share
``brainpy.math.share`` as the global context for sharing data across all modules/nodes
2 parents 6cb4b86 + 96a0652 commit 176565f

File tree

17 files changed

+272
-466
lines changed

17 files changed

+272
-466
lines changed

brainpy/_src/dyn/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(
134134
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
135135

136136
# super initialization
137-
super(DynamicalSystem, self).__init__(name=name)
137+
BrainPyObject.__init__(self, name=name)
138138

139139
@property
140140
def mode(self) -> bm.Mode:
@@ -155,7 +155,8 @@ def __call__(self, *args, **kwargs):
155155
"""The shortcut to call ``update`` methods."""
156156
if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
157157
if len(args) and isinstance(args[0], dict):
158-
bm.share.save_shargs(**args[0])
158+
for k, v in args[0].items():
159+
bm.share.save(k, v)
159160
return self.update(*args[1:], **kwargs)
160161
else:
161162
return self.update(*args, **kwargs)

brainpy/_src/dyn/runners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,8 @@ def _step_func_predict(self, shared_args, t, i, x):
615615
# input step
616616
shared = tools.DotDict(t=t, i=i, dt=self.dt)
617617
shared.update(shared_args)
618-
bm.share.save_shargs(**shared)
618+
for k, v in shared.items():
619+
bm.share.save(k, v)
619620
self.target.clear_input()
620621
self._step_func_input(shared)
621622

@@ -630,7 +631,6 @@ def _step_func_predict(self, shared_args, t, i, x):
630631
# finally
631632
if self.progress_bar:
632633
id_tap(lambda *arg: self._pbar.update(), ())
633-
bm.share.remove_shargs()
634634
return out, mon
635635

636636
def _get_f_predict(self, shared_args: Dict = None):

brainpy/_src/experimental/delay.py

Lines changed: 20 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -1,300 +1,47 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Union, Callable, Optional, Tuple, Sequence, Dict
3+
from typing import Union, Callable, Optional, Dict
44

55
import jax
6-
import jax.numpy as jnp
7-
import numpy as np
8-
from jax.lax import stop_gradient
96

10-
from brainpy import check, math as bm
11-
from brainpy._src.math.object_transform.base import Collector
7+
from brainpy import math as bm
128
from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs
13-
from brainpy.check import is_integer, jit_error_checking
9+
from brainpy._src.math.delayvars import DelayVariable, ROTATE_UPDATE, CONCAT_UPDATE
1410

15-
ROTATE_UPDATE = 'rotation'
16-
CONCAT_UPDATE = 'concat'
1711

12+
class Delay(DynamicalSystem, DelayVariable):
13+
"""Delay for dynamical systems which has a fixed delay length.
1814
19-
class Delay(DynamicalSystem):
20-
"""Delay variable which has a fixed delay length.
21-
22-
The data in this delay variable is arranged as::
23-
24-
delay = 0 [ data
25-
delay = 1 data
26-
delay = 2 data
27-
... ....
28-
... ....
29-
delay = length-1 data
30-
delay = length data ]
31-
32-
Parameters
33-
----------
34-
target: Variable
35-
The initial delay data.
36-
length: int
37-
The delay data length.
38-
initial_delay_data: Any
39-
The delay data. It can be a Python number, like float, int, boolean values.
40-
It can also be arrays. Or a callable function or instance of ``Connector``.
41-
Note that ``initial_delay_data`` should be arranged as the following way::
42-
43-
delay = 1 [ data
44-
delay = 2 data
45-
... ....
46-
... ....
47-
delay = length-1 data
48-
delay = length data ]
49-
method: str
50-
The method used for updating delay.
51-
15+
Detailed docstring please see :py:class:`~.DelayVariable`.
5216
"""
5317

54-
data: Optional[bm.Variable]
55-
idx: Optional[bm.Variable]
56-
length: int
57-
5818
def __init__(
5919
self,
6020
target: bm.Variable,
6121
length: int = 0,
62-
initial_delay_data: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
22+
before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
6323
entries: Optional[Dict] = None,
24+
method: str = ROTATE_UPDATE,
6425
mode: bm.Mode = None,
6526
name: str = None,
66-
method: str = None,
6727
):
68-
super().__init__(mode=mode, name=name)
69-
70-
# delay updating method
28+
DynamicalSystem.__init__(self, mode=mode)
7129
if method is None:
7230
if self.mode.is_a(bm.NonBatchingMode):
7331
method = ROTATE_UPDATE
74-
else:
32+
elif self.mode.is_parent_of(bm.TrainingMode):
7533
method = CONCAT_UPDATE
76-
assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
77-
self.method = method
78-
79-
# target
80-
self.target = target
81-
if not isinstance(target, bm.Variable):
82-
raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}')
83-
84-
# delay length
85-
self.length = is_integer(length, allow_none=False, min_bound=0)
86-
87-
# delay data
88-
if initial_delay_data is not None:
89-
assert isinstance(initial_delay_data, (int, float, bool, bm.Array, jax.Array, Callable))
90-
self._initial_delay_data = initial_delay_data
91-
if length > 0:
92-
self._init_data(length)
93-
else:
94-
self.data = None
95-
96-
# time variables
97-
if self.method == ROTATE_UPDATE:
98-
self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))
99-
100-
# other info
101-
self._access_to_step = dict()
102-
for entry, value in entries.items():
103-
self.register_entry(entry, value)
104-
105-
def register_entry(
106-
self,
107-
entry: str,
108-
delay_time: Optional[Union[float, bm.Array, Callable]] = None,
109-
delay_step: Optional[Union[int, bm.Array, Callable]] = None,
110-
) -> 'Delay':
111-
"""Register an entry to access the data.
112-
113-
Args:
114-
entry (str): The entry to access the delay data.
115-
delay_step: The delay step of the entry (must be an integer, denoting the delay step).
116-
delay_time: The delay time of the entry (can be a float).
117-
118-
Returns:
119-
Return the self.
120-
"""
121-
if entry in self._access_to_step:
122-
raise KeyError(f'Entry {entry} has been registered.')
123-
124-
if delay_time is not None:
125-
if delay_step is not None:
126-
raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.')
127-
if callable(delay_time):
128-
delay_time = bm.as_jax(delay_time(self.delay_target_shape))
129-
delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int())
130-
elif isinstance(delay_time, float):
131-
delay_step = int(delay_time / bm.get_dt())
13234
else:
133-
delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int())
134-
135-
# delay steps
136-
if delay_step is None:
137-
delay_type = 'none'
138-
elif isinstance(delay_step, int):
139-
delay_type = 'homo'
140-
elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)):
141-
if delay_step.size == 1 and delay_step.ndim == 0:
142-
delay_type = 'homo'
143-
else:
144-
delay_type = 'heter'
145-
delay_step = bm.Array(delay_step)
146-
elif callable(delay_step):
147-
delay_step = delay_step(self.delay_target_shape)
148-
delay_type = 'heter'
149-
else:
150-
raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
151-
f'integer, array of integers, callable function, brainpy.init.Initializer.')
152-
if delay_type == 'heter':
153-
if delay_step.dtype not in [jnp.int32, jnp.int64]:
154-
raise ValueError('Only support delay steps of int32, int64. If your '
155-
'provide delay time length, please divide the "dt" '
156-
'then provide us the number of delay steps.')
157-
if self.delay_target_shape[0] != delay_step.shape[0]:
158-
raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}')
159-
if delay_type == 'heter':
160-
max_delay_step = int(max(delay_step))
161-
elif delay_type == 'homo':
162-
max_delay_step = delay_step
163-
else:
164-
max_delay_step = None
165-
166-
# delay variable
167-
if max_delay_step is not None:
168-
if self.length < max_delay_step:
169-
self._init_data(max_delay_step)
170-
self.length = max_delay_step
171-
self._access_to_step[entry] = delay_step
172-
return self
173-
174-
def at_entry(self, entry: str, *indices) -> bm.Array:
175-
"""Get the data at the given entry.
176-
177-
Args:
178-
entry (str): The entry to access the data.
179-
*indices:
180-
181-
Returns:
182-
The data.
183-
"""
184-
assert isinstance(entry, str)
185-
if entry not in self._access_to_step:
186-
raise KeyError(f'Does not find delay entry "{entry}".')
187-
delay_step = self._access_to_step[entry]
188-
if delay_step is None:
189-
return self.target.value
190-
else:
191-
if self.data is None:
192-
return self.target.value
193-
else:
194-
if isinstance(delay_step, slice):
195-
return self.retrieve(delay_step, *indices)
196-
elif np.ndim(delay_step) == 0:
197-
return self.retrieve(delay_step, *indices)
198-
else:
199-
if len(indices) == 0 and len(delay_step) == self.target.shape[0]:
200-
indices = (jnp.arange(delay_step.size),)
201-
return self.retrieve(delay_step, *indices)
202-
203-
@property
204-
def delay_target_shape(self):
205-
"""The data shape of the delay target."""
206-
return self.target.shape
207-
208-
def __repr__(self):
209-
name = self.__class__.__name__
210-
return (f'{name}(num_delay_step={self.length}, '
211-
f'delay_target_shape={self.delay_target_shape}, '
212-
f'update_method={self.method})')
213-
214-
def _check_delay(self, delay_len):
215-
raise ValueError(f'The request delay length should be less than the '
216-
f'maximum delay {self.length}. '
217-
f'But we got {delay_len}')
218-
219-
def retrieve(self, delay_step, *indices):
220-
"""Retrieve the delay data according to the delay length.
221-
222-
Parameters
223-
----------
224-
delay_step: int, ArrayType
225-
The delay length used to retrieve the data.
226-
"""
227-
assert delay_step is not None
228-
if check.is_checking():
229-
jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step)
230-
231-
if self.method == ROTATE_UPDATE:
232-
delay_idx = (self.idx.value + delay_step) % (self.length + 1)
233-
delay_idx = stop_gradient(delay_idx)
234-
235-
elif self.method == CONCAT_UPDATE:
236-
delay_idx = delay_step
237-
238-
else:
239-
raise ValueError(f'Unknown updating method "{self.method}"')
240-
241-
# the delay index
242-
if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
243-
raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
244-
indices = (delay_idx,) + tuple(indices)
245-
246-
# the delay data
247-
return self.data[indices]
35+
method = ROTATE_UPDATE
36+
DelayVariable.__init__(self,
37+
target=target,
38+
length=length,
39+
before_t0=before_t0,
40+
entries=entries,
41+
method=method,
42+
name=name)
24843

24944
@not_pass_shargs
250-
def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None:
251-
"""Update delay variable with the new data.
252-
"""
253-
if self.data is not None:
254-
# get the latest target value
255-
if latest_value is None:
256-
latest_value = self.target.value
257-
258-
# update the delay data at the rotation index
259-
if self.method == ROTATE_UPDATE:
260-
self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1)))
261-
self.data[self.idx.value] = latest_value
262-
263-
# update the delay data at the first position
264-
elif self.method == CONCAT_UPDATE:
265-
if self.length >= 2:
266-
self.data.value = bm.vstack([latest_value, self.data[1:]])
267-
else:
268-
self.data[0] = latest_value
269-
270-
def reset_state(self, batch_size: int = None):
271-
"""Reset the delay data.
272-
"""
273-
# initialize delay data
274-
if self.data is not None:
275-
self._init_data(self.length, batch_size)
276-
277-
# time variables
278-
if self.method == ROTATE_UPDATE:
279-
self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32))
280-
281-
def _init_data(self, length, batch_size: int = None):
282-
if batch_size is not None:
283-
if self.target.batch_size != batch_size:
284-
raise ValueError(f'The batch sizes of delay variable and target variable differ '
285-
f'({self.target.batch_size} != {batch_size}). '
286-
'Please reset the target variable first, because delay data '
287-
'depends on the target variable. ')
45+
def update(self, *args, **kwargs):
46+
return DelayVariable.update(self, *args, **kwargs)
28847

289-
if self.target.batch_axis is None:
290-
batch_axis = None
291-
else:
292-
batch_axis = self.target.batch_axis + 1
293-
self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype),
294-
batch_axis=batch_axis)
295-
# update delay data
296-
self.data[0] = self.target.value
297-
if isinstance(self._initial_delay_data, (bm.Array, jax.Array, float, int, bool)):
298-
self.data[1:] = self._initial_delay_data
299-
elif callable(self._initial_delay_data):
300-
self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype)

brainpy/_src/experimental/neurons.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ class LIF(NeuGroup):
5252
Refractory period length.(ms)
5353
V_initializer: ArrayType, Initializer, callable
5454
The initializer of membrane potential.
55-
noise: ArrayType, Initializer, callable
56-
The noise added onto the membrane potential
5755
method: str
5856
The numerical integration method.
5957
name: str
@@ -125,7 +123,7 @@ def reset_state(self, batch_size=None):
125123

126124
@not_pass_shargs
127125
def update(self, current):
128-
t = bm.share.get('t')
126+
t = bm.share.load('t')
129127

130128
# integrate membrane potential
131129
V = self.integral(self.V.value, t, current, bm.dt)

brainpy/_src/experimental/synapses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def update(self, pre_spike):
255255
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
256256

257257
# updates
258-
self.g.value = self.integral(self.g.value, bm.share.get('t'), bm.dt) + post_vs
258+
self.g.value = self.integral(self.g.value, bm.share.load('t'), bm.dt) + post_vs
259259

260260
# outputs
261261
if self.out is not None:

0 commit comments

Comments
 (0)