
Commit b4ef718

Support initializing a Variable by data shape (#265)

2 parents b8691ae + 359dcbe commit b4ef718

File tree

12 files changed: +271 −78 lines changed

brainpy/dyn/base.py

Lines changed: 2 additions & 2 deletions
@@ -756,8 +756,8 @@ def __init__(

  def __repr__(self):
    names = self.__class__.__name__
-   return (f'{names}(name={self.name}, mode={self.mode}, '
-           f'{" " * len(names)} pre={self.pre}, '
+   return (f'{names}(name={self.name}, mode={self.mode}, \n'
+           f'{" " * len(names)} pre={self.pre}, \n'
            f'{" " * len(names)} post={self.post})')

  def check_pre_attrs(self, *attrs):

brainpy/dyn/layers/dropout.py

Lines changed: 0 additions & 4 deletions
@@ -15,10 +15,6 @@ class Dropout(DynamicalSystem):

  In training, to compensate for the fraction of input values dropped (`rate`),
  all surviving values are multiplied by `1 / (1 - rate)`.

- The parameter `shared_axes` allows to specify a list of axes on which
- the mask will be shared: we will use size 1 on those axes for dropout mask
- and broadcast it. Sharing reduces randomness, but can save memory.
-
  This layer is active only during training (`mode='train'`). In other
  circumstances it is a no-op.
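The scaling described in the Dropout docstring above is standard inverted dropout. A minimal sketch of the idea, not BrainPy's actual Dropout implementation (bm.random.random and bm.where are assumed to mirror their NumPy counterparts):

import brainpy.math as bm

def inverted_dropout(x, rate, training=True):
  # keep each element with probability (1 - rate); scale the survivors by
  # 1 / (1 - rate) so the expected activation is unchanged
  if not training:
    return x
  keep = bm.random.random(x.shape) < (1 - rate)
  return bm.where(keep, x / (1 - rate), 0.)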

brainpy/dyn/synapses/gap_junction.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ def __init__(
                     conn=conn,
                     name=name)
    # checking
-   self.check_pre_attrs('V', 'spike')
-   self.check_post_attrs('V', 'input', 'spike')
+   self.check_pre_attrs('V')
+   self.check_post_attrs('V', 'input')

    # assert isinstance(self.output, _NullSynOut)
    # assert isinstance(self.stp, _NullSynSTP)

brainpy/math/jaxarray.py

Lines changed: 37 additions & 5 deletions
@@ -885,10 +885,42 @@ def __jax_array__(self):


class Variable(JaxArray):
  """The pointer to specify the dynamical variable.
+
+ An instance of ``Variable`` can be initialized in two ways:
+
+ >>> import brainpy.math as bm
+ >>> # 1. init a Variable with concrete data
+ >>> v1 = bm.Variable(bm.zeros(10))
+ >>> # 2. init a Variable with the data shape
+ >>> v2 = bm.Variable(10)
+
+ Note that when a `Variable` is initialized from a shape,
+ all of its values are initialized to zero.
+
+ Parameters
+ ----------
+ value_or_size: Shape, Array
+   The value, or the shape of the value.
+ dtype:
+   The data type.
+ batch_axis: optional, int
+   The batch axis.
  """
  __slots__ = ('_value', '_batch_axis')

- def __init__(self, value, dtype=None, batch_axis: int = None):
+ def __init__(
+     self,
+     value_or_size,
+     dtype=None,
+     batch_axis: int = None
+ ):
+   if isinstance(value_or_size, int):
+     value = jnp.zeros(value_or_size, dtype=dtype)
+   elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
+     value = jnp.zeros(value_or_size, dtype=dtype)
+   else:
+     value = value_or_size
+
    super(Variable, self).__init__(value, dtype=dtype)

    # check batch axis

@@ -1464,17 +1496,17 @@ class TrainVar(Variable):
  """
  __slots__ = ('_value', '_batch_axis')

- def __init__(self, value, dtype=None, batch_axis: int = None):
-   super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
+ def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
+   super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


class Parameter(Variable):
  """The pointer to specify the parameter.
  """
  __slots__ = ('_value', '_batch_axis')

- def __init__(self, value, dtype=None, batch_axis: int = None):
-   super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
+ def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
+   super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


register_pytree_node(JaxArray,
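A hedged usage sketch of the new shape-based initialization. The tuple-shape call and the batch_axis example below are illustrations built on the __init__ logic shown in this diff, not examples copied from the BrainPy documentation:

import brainpy.math as bm

# integer shape: a length-10 vector of zeros
v1 = bm.Variable(10)

# tuple shape plus an explicit dtype
# (bm.float32 is assumed to be re-exported by brainpy.math)
v2 = bm.Variable((2, 3), dtype=bm.float32)

# concrete data still works exactly as before
v3 = bm.Variable(bm.ones((4, 5)), batch_axis=0)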

brainpy/math/tests/test_jaxarray.py

Lines changed: 11 additions & 0 deletions
@@ -40,3 +40,14 @@ def test_none(self):
    ee = a + e


+class TestVariable(unittest.TestCase):
+  def test_variable_init(self):
+    self.assertTrue(
+      bm.array_equal(bm.Variable(bm.zeros(10)),
+                     bm.Variable(10))
+    )
+    bm.random.seed(123)
+    self.assertTrue(
+      not bm.array_equal(bm.Variable(bm.random.rand(10)),
+                         bm.Variable(10))
+    )
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+
+"""
+Implementation of the paper:
+
+- Fazli, Mehran, and Richard Bertram. "Network Properties of Electrically
+  Coupled Bursting Pituitary Cells." Frontiers in Endocrinology 13 (2022).
+"""
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+class PituitaryCell(bp.NeuGroup):
+  def __init__(self, size, name=None):
+    super(PituitaryCell, self).__init__(size, name=name)
+
+    # parameter values
+    self.vn = -5
+    self.kc = 0.12
+    self.ff = 0.005
+    self.vca = 60
+    self.vk = -75
+    self.vl = -50.0
+    self.gk = 2.5
+    self.cm = 5
+    self.gbk = 1
+    self.gca = 2.1
+    self.gsk = 2
+    self.vm = -20
+    self.vb = -5
+    self.sn = 10
+    self.sm = 12
+    self.sbk = 2
+    self.taun = 30
+    self.taubk = 5
+    self.ks = 0.4
+    self.alpha = 0.0015
+    self.gl = 0.2
+
+    # variables
+    self.V = bm.Variable(bm.random.random(self.num) * -90 + 20)
+    self.n = bm.Variable(bm.random.random(self.num) / 2)
+    self.b = bm.Variable(bm.random.random(self.num) / 2)
+    self.c = bm.Variable(bm.random.random(self.num))
+    self.input = bm.Variable(self.num)
+
+    # integrators
+    self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler')
+
+  def dn(self, n, t, V):
+    ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn))
+    return (ninf - n) / self.taun
+
+  def db(self, b, t, V):
+    bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk))
+    return (bkinf - b) / self.taubk
+
+  def dc(self, c, t, V):
+    minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
+    ica = self.gca * minf * (V - self.vca)
+    return -self.ff * (self.alpha * ica + self.kc * c)
+
+  def dV(self, V, t, n, b, c):
+    minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
+    cinf = c ** 2 / (c ** 2 + self.ks * self.ks)
+    ica = self.gca * minf * (V - self.vca)
+    isk = self.gsk * cinf * (V - self.vk)
+    ibk = self.gbk * b * (V - self.vk)
+    ikdr = self.gk * n * (V - self.vk)
+    il = self.gl * (V - self.vl)
+    return -(ica + isk + ibk + ikdr + il + self.input) / self.cm
+
+  def update(self, tdi, x=None):
+    V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt)
+    self.V.value = V
+    self.n.value = n
+    self.c.value = c
+    self.b.value = b
+
+  def clear_input(self):
+    self.input.value = bm.zeros_like(self.input)
+
+
+class PituitaryNetwork(bp.Network):
+  def __init__(self, num, gc):
+    super(PituitaryNetwork, self).__init__()
+
+    self.N = PituitaryCell(num)
+    self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc)
+
+
+if __name__ == '__main__':
+  net = PituitaryNetwork(2, 0.002)
+  runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5)
+  runner.run(10 * 1e3)
+
+  fig, gs = bp.visualize.get_figure(1, 1, 6, 10)
+  fig.add_subplot(gs[0, 0])
+  bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True)
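Note how this example exercises the new feature: `self.input = bm.Variable(self.num)` allocates the per-neuron input current from a shape alone. Under the initialization rule added in this commit it is equivalent to the explicit form below (shown only for illustration):

# equivalent explicit initialization: a zero-filled input current per neuron
self.input = bm.Variable(bm.zeros(self.num))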

extensions/brainpylib/atomic_sum.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def _atomic_sum_translation(c, values, pre_ids, post_ids, *, post_num, platform=
      shape_with_layout=x_shape(np.dtype(values_dtype), (post_num,), (0,)),
    )
  elif platform == 'gpu':
-   if gpu_ops is None: raise ValueError('Cannot find compiled gpu wheels.')
+   if gpu_ops is None:
+     raise ValueError('Cannot find compiled gpu wheels.')

    opaque = gpu_ops.build_atomic_sum_descriptor(conn_size, post_num)
    if values_dim[0] != 1:

extensions/brainpylib/event_sum.py

Lines changed: 30 additions & 30 deletions
@@ -7,14 +7,17 @@

from functools import partial

+from typing import Union, Tuple
import jax.numpy as jnp
import numpy as np
-from jax import core
+from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.interpreters import xla, batching
from jax.lax import scan
from jax.lib import xla_client

+from .utils import GPUOperatorNotFound
+
try:
  from . import gpu_ops
except ImportError:

@@ -26,7 +29,10 @@
_event_sum_prim = core.Primitive("event_sum")


-def event_sum(events, pre2post, post_num, values):
+def event_sum(events: jnp.ndarray,
+              pre2post: Tuple[jnp.ndarray, jnp.ndarray],
+              post_num: int,
+              values: Union[float, jnp.ndarray]):
  # events
  if events.dtype != jnp.bool_:
    raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}')

@@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values):
  if indices.dtype != indptr.dtype:
    raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], "
                     f"while we got {(indices.dtype, indptr.dtype)}")
- if indices.dtype not in [jnp.uint32, jnp.uint64]:
-   raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}')
+ if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]:
+   raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}')

  # output value
- values = jnp.asarray([values])
- if values.dtype not in [jnp.float32, jnp.float64]:
-   raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.')
- if values.size not in [1, indices.size]:
+ dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values))
+ if dtype not in [jnp.float32, jnp.float64]:
+   raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.')
+ if np.size(values) not in [1, indices.size]:
    raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), '
-                    f'while we got {values.size} != 1 != {indices.size}')
- values = values.flatten()
+                    f'while we got {np.size(values)} != 1 != {indices.size}')
  # bind operator
  return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num)

@@ -58,34 +63,27 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num):
  return ShapedArray(dtype=values.dtype, shape=(post_num,))


-_event_sum_prim.def_abstract_eval(_event_sum_abstract)
-_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
-
-
def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
- # The pre/post shape
+ # The shape of pre/post
  pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32)
  _pre_shape = x_shape(np.dtype(np.uint32), (), ())
  _post_shape = x_shape(np.dtype(np.uint32), (), ())

  # The indices shape
  indices_shape = c.get_shape(indices)
  Itype = indices_shape.element_type()
- assert Itype in [np.uint32, np.uint64]

  # The value shape
  values_shape = c.get_shape(values)
  Ftype = values_shape.element_type()
- assert Ftype in [np.float32, np.float64]
  values_dim = values_shape.dimensions()

  # We dispatch a different call depending on the dtype
- f_type = b'_f32' if Ftype == np.float32 else b'_f64'
- i_type = b'_i32' if Itype == np.uint32 else b'_i64'
+ f_type = b'_f32' if Ftype in np.float32 else b'_f64'
+ i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64'

- # And then the following is what changes between the GPU and CPU
  if platform == "cpu":
-   v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
+   v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter'
    return x_ops.CustomCallWithLayout(
      c,
      platform.encode() + v_type + f_type + i_type,

@@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
                         c.get_shape(values)),
      shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)),
    )
+
+ # GPU platform
  elif platform == 'gpu':
    if gpu_ops is None:
-     raise ValueError('Cannot find compiled gpu wheels.')
+     raise GPUOperatorNotFound('event_sum')
+
    v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
    opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num)
    return x_ops.CustomCallWithLayout(

@@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
    raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'")


-xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
-xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")
-
-
-def _event_sum_batch(args, axes):
+def _event_sum_batch(args, axes, *, post_num):
  batch_axes, batch_args, non_batch_args = [], {}, {}
  for ax_i, ax in enumerate(axes):
    if ax is None:

@@ -143,19 +140,22 @@ def _event_sum_batch(args, axes):
  def f(_, x):
    pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
                  for i in range(len(axes))])
-   return 0, _event_sum_prim.bind(*pars)
+   return 0, _event_sum_prim.bind(*pars, post_num=post_num)
+
  _, outs = scan(f, 0, batch_args)
  return outs, 0


+_event_sum_prim.def_abstract_eval(_event_sum_abstract)
+_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
batching.primitive_batchers[_event_sum_prim] = _event_sum_batch
-
+xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
+xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")

# ---------------------------
# event sum kernel 2
# ---------------------------

-
_event_sum2_prim = core.Primitive("event_sum2")


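A hedged sketch of how the updated event_sum signature might be called. The CSR arrays and sizes below are made-up illustration, not taken from the BrainPy test suite, and the top-level import assumes the operator is re-exported by the brainpylib package:

import jax.numpy as jnp
from brainpylib import event_sum  # assumed top-level re-export

# 3 presynaptic neurons projecting to 4 postsynaptic neurons, stored as CSR
indices = jnp.asarray([0, 2, 1, 3, 2], dtype=jnp.int32)  # post ids of each connection
indptr = jnp.asarray([0, 2, 4, 5], dtype=jnp.int32)      # per-pre-neuron offsets into indices
events = jnp.asarray([True, False, True])                 # presynaptic spikes (bool vector)

# a homogeneous scalar weight now passes the relaxed dtype/size checks
post_current = event_sum(events, (indices, indptr), 4, 0.5)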
extensions/brainpylib/utils.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+
+
+__all__ = [
+  'GPUOperatorNotFound',
+]
+
+
+class GPUOperatorNotFound(Exception):
+  def __init__(self, name):
+    super(GPUOperatorNotFound, self).__init__(f'''
+GPU operator for "{name}" was not found.
+
+Please compile the brainpylib GPU operators following the guidance at:
+
+https://brainpy.readthedocs.io/en/latest/tutorial_advanced/compile_brainpylib.html
+''')
