Skip to content

Commit 0edfadc

Browse files
authored
Merge pull request #245 from chaoming0625/master
update apis and examples
2 parents b65b766 + 41c3ef9 commit 0edfadc

File tree

14 files changed

+95
-76
lines changed

14 files changed

+95
-76
lines changed

.github/workflows/Sync_branches.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ jobs:
99
steps:
1010
- uses: actions/checkout@master
1111

12-
- name: Merge master -> brainpy-2.x
12+
- name: Merge master -> brainpy-2.2.x
1313
uses: devmasx/merge-branch@master
1414
with:
1515
type: now
1616
from_branch: master
17-
target_branch: brainpy-2.x
17+
target_branch: brainpy-2.2.x
1818
github_token: ${{ github.token }}

brainpy/base/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class Base(object):
2929
3030
"""
3131

32+
_excluded_vars = ()
33+
3234
def __init__(self, name=None):
3335
# check whether the object has a unique name.
3436
self._name = None
@@ -120,8 +122,10 @@ def vars(self, method='absolute', level=-1, include_self=True):
120122
for node_path, node in nodes.items():
121123
for k in dir(node):
122124
v = getattr(node, k)
123-
if isinstance(v, math.Variable) and not k.startswith('_') and not k.endswith('_'):
124-
gather[f'{node_path}.{k}' if node_path else k] = v
125+
if isinstance(v, math.Variable):
126+
if k not in node._excluded_vars:
127+
# if not k.startswith('_') and not k.endswith('_'):
128+
gather[f'{node_path}.{k}' if node_path else k] = v
125129
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
126130
return gather
127131

brainpy/dyn/neurons/fractional_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __init__(
226226
self,
227227
size: Shape,
228228
alpha: Union[float, Sequence[float]],
229-
num_step: int,
229+
num_memory: int,
230230
a: Union[float, Tensor, Initializer, Callable] = 0.02,
231231
b: Union[float, Tensor, Initializer, Callable] = 0.20,
232232
c: Union[float, Tensor, Initializer, Callable] = -65.,
@@ -272,10 +272,10 @@ def __init__(
272272
self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))
273273

274274
# functions
275-
check_integer(num_step, 'num_step', allow_none=False)
275+
check_integer(num_memory, 'num_step', allow_none=False)
276276
self.integral = CaputoL1Schema(f=self.derivative,
277277
alpha=alpha,
278-
num_memory=num_step,
278+
num_memory=num_memory,
279279
inits=[self.V, self.u])
280280

281281
def reset_state(self, batch_size=None):

brainpy/inputs/currents.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ def constant_input(I_and_duration, dt=None):
113113

114114
# get the current
115115
start = 0
116-
I_current = jnp.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
116+
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape)
117117
for c_size, duration in I_and_duration:
118118
length = int(duration / dt)
119-
I_current = I_current.at[start: start + length].set(c_size)
119+
I_current[start: start + length] = c_size
120120
start += length
121-
return I_current, I_duration
121+
return I_current.value, I_duration
122122

123123

124124
def constant_current(*args, **kwargs):
@@ -172,12 +172,12 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None):
172172
if isinstance(sp_sizes, (float, int)):
173173
sp_sizes = [sp_sizes] * len(sp_times)
174174

175-
current = jnp.zeros(int(np.ceil(duration / dt)))
175+
current = bm.zeros(int(np.ceil(duration / dt)))
176176
for time, dur, size in zip(sp_times, sp_lens, sp_sizes):
177177
pp = int(time / dt)
178178
p_len = int(dur / dt)
179-
current = current.at[pp: pp + p_len].set(size)
180-
return current
179+
current[pp: pp + p_len] = size
180+
return current.value
181181

182182

183183
def spike_current(*args, **kwargs):
@@ -218,12 +218,12 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
218218
dt = bm.get_dt() if dt is None else dt
219219
t_end = duration if t_end is None else t_end
220220

221-
current = jnp.zeros(int(np.ceil(duration / dt)))
221+
current = bm.zeros(int(np.ceil(duration / dt)))
222222
p1 = int(np.ceil(t_start / dt))
223223
p2 = int(np.ceil(t_end / dt))
224224
cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1))
225-
current = current.at[p1: p2].set(cc)
226-
return current
225+
current[p1: p2] = cc
226+
return current.value
227227

228228

229229
def ramp_current(*args, **kwargs):
@@ -265,9 +265,9 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
265265
i_start = int(t_start / dt)
266266
i_end = int(t_end / dt)
267267
noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt)
268-
currents = jnp.zeros((int(duration / dt), n))
269-
currents = currents.at[i_start: i_end].set(bm.as_device_array(noises))
270-
return currents
268+
currents = bm.zeros((int(duration / dt), n))
269+
currents[i_start: i_end] = noises
270+
return currents.value
271271

272272

273273
def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None):

brainpy/integrators/fde/Caputo.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __init__(
154154

155155
def _check_step(self, args):
156156
dt, t = args
157-
raise ValueError(f'The maximum number of step is {self.num_step}, '
157+
raise ValueError(f'The maximum number of step is {self.num_memory}, '
158158
f'however, the current time {t} require a time '
159159
f'step number {t / dt}.')
160160

@@ -164,7 +164,7 @@ def _integral_func(self, *args, **kwargs):
164164
t = all_args['t']
165165
dt = all_args.pop(DT, self.dt)
166166
if check.is_checking():
167-
check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
167+
check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))
168168

169169
# derivative values
170170
devs = self.f(**all_args)
@@ -185,11 +185,11 @@ def _integral_func(self, *args, **kwargs):
185185

186186
# integral results
187187
integrals = []
188-
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
188+
idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
189189
for i, key in enumerate(self.variables):
190190
integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key]
191191
integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i]))
192-
self.idx.value = (self.idx + 1) % self.num_step
192+
self.idx.value = (self.idx + 1) % self.num_memory
193193

194194
# return integrals
195195
if len(self.variables) == 1:
@@ -344,19 +344,19 @@ def __init__(
344344
dtype=self.inits[v].dtype))
345345
for v in self.variables}
346346
self.register_implicit_vars(self.diff_states)
347-
self.idx = bm.Variable(bm.asarray([self.num_step - 1]))
347+
self.idx = bm.Variable(bm.asarray([self.num_memory - 1]))
348348

349349
# integral function
350350
self.set_integral(self._integral_func)
351351

352352
def reset(self, inits):
353353
"""Reset function."""
354-
self.idx.value = bm.asarray([self.num_step - 1])
354+
self.idx.value = bm.asarray([self.num_memory - 1])
355355
inits = check_inits(inits, self.variables)
356356
for key, value in inits.items():
357357
self.inits[key].value = value
358358
for key, val in inits.items():
359-
self.diff_states[key + "_diff"].value = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
359+
self.diff_states[key + "_diff"].value = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
360360

361361
def hists(self, var=None, numpy=True):
362362
"""Get the recorded history values."""
@@ -378,7 +378,7 @@ def hists(self, var=None, numpy=True):
378378

379379
def _check_step(self, args):
380380
dt, t = args
381-
raise ValueError(f'The maximum number of step is {self.num_step}, '
381+
raise ValueError(f'The maximum number of step is {self.num_memory}, '
382382
f'however, the current time {t} require a time '
383383
f'step number {t / dt}.')
384384

@@ -388,7 +388,7 @@ def _integral_func(self, *args, **kwargs):
388388
t = all_args['t']
389389
dt = all_args.pop(DT, self.dt)
390390
if check.is_checking():
391-
check_error_in_jit(self.num_step * dt < t, self._check_step, (dt, t))
391+
check_error_in_jit(self.num_memory * dt < t, self._check_step, (dt, t))
392392

393393
# derivative values
394394
devs = self.f(**all_args)
@@ -405,15 +405,15 @@ def _integral_func(self, *args, **kwargs):
405405

406406
# integral results
407407
integrals = []
408-
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step
408+
idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory
409409
for i, key in enumerate(self.variables):
410410
self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key]
411411
self.inits[key].value = all_args[key]
412412
markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key]
413413
memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff']
414414
integral = markov_term - memory_trace
415415
integrals.append(integral)
416-
self.idx.value = (self.idx + 1) % self.num_step
416+
self.idx.value = (self.idx + 1) % self.num_memory
417417

418418
# return integrals
419419
if len(self.variables) == 1:

brainpy/integrators/fde/GL.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This module provides numerical solvers for Grünwald–Letnikov derivative FDEs.
55
"""
66

7-
from typing import Dict, Union, Callable
7+
from typing import Dict, Union, Callable, Any
88

99
import jax.numpy as jnp
1010

@@ -127,8 +127,8 @@ class GLShortMemory(FDEIntegrator):
127127
def __init__(
128128
self,
129129
f: Callable,
130-
alpha,
131-
inits,
130+
alpha: Any,
131+
inits: Any,
132132
num_memory: int,
133133
dt: float = None,
134134
name: str = None,
@@ -152,9 +152,9 @@ def __init__(
152152
# delays
153153
self.delays = {}
154154
for key, val in inits.items():
155-
delay = bm.Variable(bm.zeros((self.num_step,) + val.shape, dtype=val.dtype))
155+
delay = bm.Variable(bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype))
156156
delay[0] = val
157-
self.delays[key] = delay
157+
self.delays[key+'_delay'] = delay
158158
self._idx = bm.Variable(bm.asarray([1]))
159159
self.register_implicit_vars(self.delays)
160160

@@ -171,7 +171,7 @@ def reset(self, inits):
171171
self._idx.value = bm.asarray([1])
172172
inits = check_inits(inits, self.variables)
173173
for key, val in inits.items():
174-
delay = bm.zeros((self.num_step,) + val.shape, dtype=val.dtype)
174+
delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype)
175175
delay[0] = val
176176
self.delays[key].value = delay
177177

@@ -199,13 +199,14 @@ def _integral_func(self, *args, **kwargs):
199199

200200
# integral results
201201
integrals = []
202-
idx = (self._idx + bm.arange(self.num_step)) % self.num_step
202+
idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory
203203
for i, var in enumerate(self.variables):
204-
summation = self._binomial_coef[:, i] @ self.delays[var][idx]
204+
delay_var = var + '_delay'
205+
summation = self._binomial_coef[:, i] @ self.delays[delay_var][idx]
205206
integral = (dt ** self.alpha[i]) * devs[var] - summation
206-
self.delays[var][self._idx[0]] = integral
207+
self.delays[delay_var][self._idx[0]] = integral
207208
integrals.append(integral)
208-
self._idx.value = (self._idx + 1) % self.num_step
209+
self._idx.value = (self._idx + 1) % self.num_memory
209210

210211
# return integrals
211212
if len(self.variables) == 1:

brainpy/integrators/fde/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def __init__(
5555
arguments = parses[2] # function arguments
5656

5757
# memory length
58-
check_integer(num_memory, 'num_step', allow_none=False, min_bound=1)
59-
self.num_step = num_memory
58+
check_integer(num_memory, 'num_memory', allow_none=False, min_bound=1)
59+
self.num_memory = num_memory
6060

6161
# super initialization
6262
super(FDEIntegrator, self).__init__(name=name,

brainpy/integrators/fde/tests/test_Caputo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test1(self):
2525

2626
intg.idx[0] = N - 1
2727
intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff)
28-
idx = ((intg.num_step - intg.idx) + np.arange(intg.num_step)) % intg.num_step
28+
idx = ((intg.num_memory - intg.idx) + np.arange(intg.num_memory)) % intg.num_memory
2929
memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff']
3030

3131
print()

brainpy/integrators/runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,11 @@ def __init__(
134134
numpy_mon_after_run: bool
135135
"""
136136

137-
# initialize variables
138137
if not isinstance(target, Integrator):
139138
raise TypeError(f'Target must be instance of {Integrator.__name__}, '
140139
f'but we got {type(target)}')
140+
141+
# get maximum size and initial variables
141142
if inits is not None:
142143
if isinstance(inits, (list, tuple, bm.JaxArray, jnp.ndarray)):
143144
assert len(target.variables) == len(inits)
@@ -148,6 +149,8 @@ def __init__(
148149
else:
149150
max_size = 1
150151
inits = dict()
152+
153+
# initialize variables
151154
self.variables = TensorCollector({v: bm.Variable(bm.zeros(max_size))
152155
for v in target.variables})
153156
for k in inits.keys():
@@ -207,7 +210,6 @@ def __init__(
207210
self.dyn_vars.update(self.target.vars().unique())
208211

209212
# Variables
210-
211213
self.dyn_vars.update(self.variables)
212214
if len(self._dyn_args) > 0:
213215
self.idx = bm.Variable(bm.zeros(1, dtype=jnp.int_))
@@ -240,11 +242,6 @@ def _loop_func(times):
240242
return out_vars, returns
241243
self.step_func = _loop_func
242244

243-
def _post(self, times, returns: dict): # monitor
244-
self.mon.ts = times + self.dt
245-
for key in returns.keys():
246-
self.mon[key] = bm.asarray(returns[key])
247-
248245
def _step(self, t):
249246
# arguments
250247
kwargs = dict()
@@ -254,17 +251,21 @@ def _step(self, t):
254251
if len(self._dyn_args) > 0:
255252
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
256253
self.idx += 1
254+
257255
# return of function monitors
258256
returns = dict()
259257
for key, func in self.fun_monitors.items():
260258
returns[key] = func(t, self.dt)
259+
261260
# call integrator function
262261
update_values = self.target(**kwargs)
263262
if len(self.target.variables) == 1:
264263
self.variables[self.target.variables[0]].update(update_values)
265264
else:
266265
for i, v in enumerate(self.target.variables):
267266
self.variables[v].update(update_values[i])
267+
268+
# progress bar
268269
if self.progress_bar:
269270
id_tap(lambda *args: self._pbar.update(), ())
270271
return returns

brainpy/integrators/sde/normal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def step(self, *args, **kwargs):
114114
# diffusion values
115115
diffusions = self.g(**all_args)
116116
if len(self.variables) == 1:
117-
if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
118-
raise ValueError('Diffusion values must be a tensor when there '
119-
'is only one variable in the equation.')
117+
# if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)):
118+
# raise ValueError('Diffusion values must be a tensor when there '
119+
# 'is only one variable in the equation.')
120120
diffusions = {self.variables[0]: diffusions}
121121
else:
122122
if not isinstance(diffusions, (tuple, list)):

0 commit comments

Comments
 (0)