Skip to content

Commit b65b766

Browse files
authored
Merge pull request #244 from chaoming0625/master
update quickstart docs & enable jit error checking
2 parents cc2cd73 + 8ba66a2 commit b65b766

File tree

17 files changed

+2055
-1655
lines changed

17 files changed

+2055
-1655
lines changed

brainpy/dyn/synapses/delay_couplings.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from jax import vmap
77

88
import brainpy.math as bm
9-
from brainpy.dyn.base import DynamicalSystem
9+
from brainpy.dyn.base import SynConn, SynOut
10+
from brainpy.dyn.synouts import CUBA
1011
from brainpy.initialize import Initializer
12+
from brainpy.dyn.neurons.input_groups import InputGroup, OutputGroup
1113
from brainpy.modes import Mode, TrainingMode, normal
1214
from brainpy.tools.checking import check_sequence
1315
from brainpy.types import Tensor
@@ -19,7 +21,7 @@
1921
]
2022

2123

22-
class DelayCoupling(DynamicalSystem):
24+
class DelayCoupling(SynConn):
2325
"""Delay coupling.
2426
2527
Parameters
@@ -49,7 +51,10 @@ def __init__(
4951
name: str = None,
5052
mode: Mode = normal,
5153
):
52-
super(DelayCoupling, self).__init__(name=name, mode=mode)
54+
super(DelayCoupling, self).__init__(name=name,
55+
mode=mode,
56+
pre=InputGroup(1),
57+
post=OutputGroup(1))
5358

5459
# delay variable
5560
if not isinstance(delay_var, bm.Variable):
@@ -201,8 +206,8 @@ def update(self, tdi):
201206
indices = (slice(None, None, None), bm.arange(self.coupling_var1.size),)
202207
else:
203208
indices = (bm.arange(self.coupling_var1.size),)
204-
f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (..., pre.num)
205-
delays = f(bm.arange(self.coupling_var2.size).value) # (..., post.num, pre.num)
209+
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num)
210+
delays = f(self.delay_steps) # (..., post.num, pre.num)
206211
diffusive = (bm.moveaxis(delays, axis - 1, axis) -
207212
bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num)
208213
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
@@ -284,8 +289,8 @@ def update(self, tdi):
284289
indices = (slice(None, None, None), bm.arange(self.coupling_var.size),)
285290
else:
286291
indices = (bm.arange(self.coupling_var.size),)
287-
f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (.., pre.num,)
288-
delays = f(bm.arange(self.coupling_var.size).value) # (..., post.num, pre.num)
292+
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,)
293+
delays = f(self.delay_steps) # (..., post.num, pre.num)
289294
additive = (self.conn_mat * bm.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1)
290295
elif self.delay_type == 'int':
291296
delayed_var = delay_var(self.delay_steps) # (..., pre.num)

brainpy/initialize/generic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,29 @@ def variable(
9898
return bm.Variable(data(new_shape), batch_axis=batch_axis)
9999
elif batch_size_or_mode in (None, False):
100100
return bm.Variable(data(var_shape))
101-
else:
101+
elif isinstance(batch_size_or_mode, int):
102102
new_shape = var_shape[:batch_axis] + (int(batch_size_or_mode),) + var_shape[batch_axis:]
103103
return bm.Variable(data(new_shape), batch_axis=batch_axis)
104+
else:
105+
raise ValueError('Unknown batch_size_or_mode.')
106+
104107
else:
105108
if var_shape is not None:
106109
if bm.shape(data) != var_shape:
107110
raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {var_shape}')
108111
if isinstance(batch_size_or_mode, NormalMode):
109-
return bm.Variable(data(var_shape))
112+
return bm.Variable(data)
110113
elif isinstance(batch_size_or_mode, BatchingMode):
111114
return bm.Variable(bm.expand_dims(data, axis=batch_axis), batch_axis=batch_axis)
112115
elif batch_size_or_mode in (None, False):
113116
return bm.Variable(data)
114-
else:
117+
elif isinstance(batch_size_or_mode, int):
115118
return bm.Variable(bm.repeat(bm.expand_dims(data, axis=batch_axis),
116119
int(batch_size_or_mode),
117120
axis=batch_axis),
118121
batch_axis=batch_axis)
122+
else:
123+
raise ValueError('Unknown batch_size_or_mode.')
119124

120125

121126
def noise(

brainpy/math/delayvars.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def reset(
327327
# delay data
328328
if self.data is None:
329329
if batch_axis is None:
330-
if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None):
330+
if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None):
331331
batch_axis = delay_target.batch_axis + 1
332332
self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
333333
dtype=delay_target.dtype),
@@ -348,7 +348,8 @@ def reset(
348348

349349
def _check_delay(self, delay_len):
350350
raise ValueError(f'The request delay length should be less than the '
351-
f'maximum delay {self.num_delay_step}. But we got {delay_len}')
351+
f'maximum delay {self.num_delay_step}. '
352+
f'But we got {delay_len}')
352353

353354
def __call__(self, delay_len, *indices):
354355
# check

brainpy/math/remove_vmap.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from brainpy.math.numpy_ops import any, all
4+
from jax.core import Primitive
5+
from jax.interpreters import batching, mlir, xla
6+
from jax.abstract_arrays import ShapedArray
7+
import jax.numpy as jnp
8+
9+
10+
__all__ = [
11+
'remove_vmap'
12+
]
13+
14+
15+
def remove_vmap(x, op='any'):
16+
if op == 'any':
17+
return _any_without_vmap(x)
18+
elif op == 'all':
19+
return _all_without_vmap(x)
20+
else:
21+
raise ValueError(f'Do not support type: {op}')
22+
23+
24+
_any_no_vmap_prim = Primitive('any_no_vmap')
25+
26+
27+
def _any_without_vmap(x):
28+
return _any_no_vmap_prim.bind(x)
29+
30+
31+
def _any_without_vmap_imp(x):
32+
return any(x)
33+
34+
35+
def _any_without_vmap_abs(x):
36+
return ShapedArray(shape=(), dtype=jnp.bool_)
37+
38+
39+
def _any_without_vmap_batch(x, batch_axes):
40+
(x, ) = x
41+
return _any_without_vmap(x), batching.not_mapped
42+
43+
44+
_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
45+
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
46+
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
47+
if hasattr(xla, "lower_fun"):
48+
xla.register_translation(_any_no_vmap_prim,
49+
xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
50+
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))
51+
52+
53+
_all_no_vmap_prim = Primitive('all_no_vmap')
54+
55+
56+
def _all_without_vmap(x):
57+
return _all_no_vmap_prim.bind(x)
58+
59+
60+
def _all_without_vmap_imp(x):
61+
return all(x)
62+
63+
64+
def _all_without_vmap_abs(x):
65+
return ShapedArray(shape=(), dtype=jnp.bool_)
66+
67+
68+
def _all_without_vmap_batch(x, batch_axes):
69+
(x, ) = x
70+
return _all_without_vmap(x), batching.not_mapped
71+
72+
73+
_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
74+
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
75+
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
76+
if hasattr(xla, "lower_fun"):
77+
xla.register_translation(_all_no_vmap_prim,
78+
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
79+
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
80+

brainpy/tools/errors.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,22 @@
99
]
1010

1111

12-
def _make_err_func(f):
13-
f2 = lambda arg, transforms: f(arg)
14-
15-
def err_f(x):
16-
id_tap(f2, x)
17-
return
18-
return err_f
19-
20-
21-
def check_error_in_jit(pred, err_f, err_arg=None):
12+
def check_error_in_jit(pred, err_fun, err_arg=None):
2213
"""Check errors in a jit function.
2314
2415
Parameters
2516
----------
2617
pred: bool
2718
The boolean prediction.
28-
err_f: callable
19+
err_fun: callable
2920
The error function, which raise errors.
3021
err_arg: any
3122
The arguments which passed into `err_f`.
3223
"""
33-
cond(pred, _make_err_func(err_f), lambda _: None, err_arg)
24+
from brainpy.math.remove_vmap import remove_vmap
3425

26+
def err_f(x):
27+
id_tap(lambda arg, transforms: err_fun(arg), x)
28+
return
29+
cond(remove_vmap(pred), err_f, lambda _: None, err_arg)
3530

docs/auto_generater.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def generate_datasets_docs(path='apis/auto/datasets/'):
257257
header='Chaotic Systems')
258258
write_module(module_name='brainpy.datasets.vision',
259259
filename=os.path.join(path, 'vision.rst'),
260-
header='Chaotic Systems')
260+
header='Vision Datasets')
261261

262262

263263
def generate_dyn_docs(path='apis/auto/dyn/'):

docs/index.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ The code of BrainPy is open-sourced at GitHub:
3838

3939
quickstart/installation
4040
quickstart/simulation
41-
quickstart/rate_model
4241
quickstart/training
4342
quickstart/analysis
4443

@@ -66,9 +65,6 @@ The code of BrainPy is open-sourced at GitHub:
6665
tutorial_toolbox/synaptic_connections
6766
tutorial_toolbox/synaptic_weights
6867
tutorial_toolbox/optimizers
69-
tutorial_toolbox/runners
70-
tutorial_toolbox/inputs
71-
tutorial_toolbox/monitors
7268
tutorial_toolbox/saving_and_loading
7369

7470

0 commit comments

Comments
 (0)