Commit 77306b6 ("Updates")

1 parent aa0d58f

File tree: 10 files changed (+186, -41 lines)

brainpy/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -109,6 +109,8 @@
 init = initialize
 optim = optimizers
 
+from . import experimental
+
 
 # deprecated
 from . import base
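This re-export makes the new subpackage (added below in brainpy/experimental/) reachable from the top-level namespace. A minimal sketch of what it enables:

import brainpy as bp

syn_cls = bp.experimental.Exponential  # the experimental synapse defined later in this commit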

brainpy/dyn/base.py

Lines changed: 2 additions & 2 deletions

@@ -841,9 +841,9 @@ def __init__(
 
     # pre or post neuron group
     # ------------------------
-    if not isinstance(pre, NeuGroup):
+    if not isinstance(pre, (NeuGroup, DynamicalSystem)):
      raise TypeError('"pre" must be an instance of NeuGroup.')
-    if not isinstance(post, NeuGroup):
+    if not isinstance(post, (NeuGroup, DynamicalSystem)):
      raise TypeError('"post" must be an instance of NeuGroup.')
     self.pre = pre
     self.post = post

brainpy/dyn/synapses/abstract_models.py

Lines changed: 1 addition & 1 deletion

@@ -280,7 +280,7 @@ def __init__(
       pre: NeuGroup,
       post: NeuGroup,
       conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
-      output: SynOut = CUBA(),
+      output: Optional[SynOut] = CUBA(),
       stp: Optional[SynSTP] = None,
       comp_method: str = 'sparse',
       g_max: Union[float, ArrayType, Initializer, Callable] = 1.,

brainpy/experimental/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+from .synapses import *

brainpy/experimental/synapses.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union, Optional
+
+import brainpylib as bl
+import jax
+
+from brainpy import (math as bm,
+                     initialize as init,
+                     connect)
+from brainpy.dyn.base import DynamicalSystem, SynSTP
+from brainpy.integrators.ode import odeint
+from brainpy.types import Initializer, ArrayType
+
+__all__ = [
+  'Exponential',
+]
+
+
+class Exponential(DynamicalSystem):
+  def __init__(
+      self,
+      conn: connect.TwoEndConnector,
+      stp: Optional[SynSTP] = None,
+      g_max: Union[float, Initializer] = 1.,
+      g_initializer: Union[float, Initializer] = init.ZeroInit(),
+      tau: Union[float, ArrayType] = 8.0,
+      method: str = 'exp_auto',
+      mode: Optional[bm.Mode] = None,
+      name: Optional[str] = None,
+  ):
+    super(Exponential, self).__init__(name=name, mode=mode)
+
+    # component
+    self.conn = conn
+    self.stp = stp
+    self.g_initializer = g_initializer
+    assert self.conn.pre_num is not None
+    assert self.conn.post_num is not None
+
+    # parameters
+    self.tau = tau
+    if bm.size(self.tau) != 1:
+      raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
+
+    # connections and weights
+    if isinstance(self.conn, connect.One2One):
+      self.g_max = init.parameter(g_max, (self.conn.pre_num,), allow_none=False)
+
+    elif isinstance(self.conn, connect.All2All):
+      self.g_max = init.parameter(g_max, (self.conn.pre_num, self.conn.post_num), allow_none=False)
+
+    else:
+      self.conn_mask = self.conn.require('pre2post')
+      self.g_max = init.parameter(g_max, self.conn_mask[0].shape, allow_none=False)
+
+    # variables
+    self.g = init.variable_(g_initializer, self.conn.post_num, self.mode)
+
+    # function
+    self.integral = odeint(lambda g, t: -g / self.tau, method=method)
+
+  def reset_state(self, batch_size=None):
+    self.g.value = init.variable_(bm.zeros, self.conn.post_num, batch_size)
+    if self.stp is not None:
+      self.stp.reset_state(batch_size)
+
+  def _syn2post_with_one2one(self, syn_value, syn_weight):
+    return syn_value * syn_weight
+
+  def _syn2post_with_all2all(self, syn_value, syn_weight):
+    if bm.ndim(syn_weight) == 0:
+      if isinstance(self.mode, bm.BatchingMode):
+        assert syn_value.ndim == 2
+        post_vs = bm.sum(syn_value, keepdims=True, axis=1)
+      else:
+        post_vs = bm.sum(syn_value)
+      if not self.conn.include_self:
+        post_vs = post_vs - syn_value
+      post_vs = syn_weight * post_vs
+    else:
+      assert syn_weight.ndim == 2
+      if isinstance(self.mode, bm.BatchingMode):
+        assert syn_value.ndim == 2
+        post_vs = syn_value @ syn_weight
+      else:
+        post_vs = syn_value @ syn_weight
+    return post_vs
+
+  def update(self, tdi, spike):
+    t, dt = tdi['t'], tdi.get('dt', bm.dt)
+
+    # update sub-components
+    if self.stp is not None:
+      self.stp.update(tdi, spike)
+
+    # post values
+    if isinstance(self.conn, connect.All2All):
+      syn_value = bm.asarray(spike, dtype=bm.float_)
+      if self.stp is not None:
+        syn_value = self.stp(syn_value)
+      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
+    elif isinstance(self.conn, connect.One2One):
+      syn_value = bm.asarray(spike, dtype=bm.float_)
+      if self.stp is not None:
+        syn_value = self.stp(syn_value)
+      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
+    else:
+      if isinstance(self.mode, bm.BatchingMode):
+        f = jax.vmap(bl.event_ops.event_csr_matvec, in_axes=(None, None, None, 0))
+        post_vs = f(self.g_max, self.conn_mask[0], self.conn_mask[1], spike,
+                    shape=(self.conn.pre_num, self.conn.post_num), transpose=True)
+      else:
+        post_vs = bl.event_ops.event_csr_matvec(
+          self.g_max, self.conn_mask[0], self.conn_mask[1], spike,
+          shape=(self.conn.pre_num, self.conn.post_num), transpose=True
+        )
+    # updates
+    self.g.value = self.integral(self.g.value, t, dt) + post_vs
+
+    # output
+    return self.g.value
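A hypothetical usage sketch of the new experimental synapse, not part of the diff. The connector, sizes, and spike vector are invented for illustration; bp.connect.FixedProb and bm.as_jax are assumed from the existing BrainPy API:

import brainpy as bp
import brainpy.math as bm

# a sparse connector, so update() takes the event_csr_matvec branch
conn = bp.connect.FixedProb(prob=0.1)(pre_size=100, post_size=200)
syn = bp.experimental.Exponential(conn, g_max=0.5, tau=8.0)

spike = bm.as_jax(bm.random.rand(100) < 0.05)  # hypothetical presynaptic spike vector
g = syn.update({'t': 0., 'dt': 0.1}, spike)    # postsynaptic conductance, shape (200,)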

brainpy/initialize/generic.py

Lines changed: 17 additions & 17 deletions

@@ -94,16 +94,16 @@ def init_param(
 
 
 def variable_(
-    data: Union[Callable, ArrayType],
+    init: Union[Callable, ArrayType],
     size: Shape = None,
     batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None,
     batch_axis: int = 0,
 ):
-  """Initialize variables. Same as `variable()`.
+  """Initialize a :math:`~.Variable` from a callable function or a data.
 
   Parameters
   ----------
-  data: callable, function, ArrayType
+  init: callable, function, ArrayType
     The data to be initialized as a ``Variable``.
   batch_size_or_mode: int, bool, Mode, optional
     The batch size, model ``Mode``, boolean state.
@@ -125,11 +125,11 @@ def variable_(
   variable, parameter, noise, delay
 
   """
-  return variable(data, batch_size_or_mode, size, batch_axis)
+  return variable(init, batch_size_or_mode, size, batch_axis)
 
 
 def variable(
-    data: Union[Callable, ArrayType],
+    init: Union[Callable, ArrayType],
     batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None,
     size: Shape = None,
     batch_axis: int = 0,
@@ -138,7 +138,7 @@ def variable(
 
   Parameters
   ----------
-  data: callable, function, ArrayType
+  init: callable, function, ArrayType
     The data to be initialized as a ``Variable``.
   batch_size_or_mode: int, bool, Mode, optional
     The batch size, model ``Mode``, boolean state.
@@ -161,34 +161,34 @@ def variable(
 
   """
   size = to_size(size)
-  if callable(data):
+  if callable(init):
     if size is None:
       raise ValueError('"varshape" cannot be None when data is a callable function.')
     if isinstance(batch_size_or_mode, bm.NonBatchingMode):
-      return bm.Variable(data(size))
+      return bm.Variable(init(size))
     elif isinstance(batch_size_or_mode, bm.BatchingMode):
       new_shape = size[:batch_axis] + (1,) + size[batch_axis:]
-      return bm.Variable(data(new_shape), batch_axis=batch_axis)
+      return bm.Variable(init(new_shape), batch_axis=batch_axis)
     elif batch_size_or_mode in (None, False):
-      return bm.Variable(data(size))
+      return bm.Variable(init(size))
     elif isinstance(batch_size_or_mode, int):
       new_shape = size[:batch_axis] + (int(batch_size_or_mode),) + size[batch_axis:]
-      return bm.Variable(data(new_shape), batch_axis=batch_axis)
+      return bm.Variable(init(new_shape), batch_axis=batch_axis)
     else:
       raise ValueError('Unknown batch_size_or_mode.')
 
   else:
     if size is not None:
-      if bm.shape(data) != size:
-        raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {size}')
+      if bm.shape(init) != size:
+        raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {size}')
     if isinstance(batch_size_or_mode, bm.NonBatchingMode):
-      return bm.Variable(data)
+      return bm.Variable(init)
     elif isinstance(batch_size_or_mode, bm.BatchingMode):
-      return bm.Variable(bm.expand_dims(data, axis=batch_axis), batch_axis=batch_axis)
+      return bm.Variable(bm.expand_dims(init, axis=batch_axis), batch_axis=batch_axis)
     elif batch_size_or_mode in (None, False):
-      return bm.Variable(data)
+      return bm.Variable(init)
     elif isinstance(batch_size_or_mode, int):
-      return bm.Variable(bm.repeat(bm.expand_dims(data, axis=batch_axis),
+      return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
                                    int(batch_size_or_mode),
                                    axis=batch_axis),
                          batch_axis=batch_axis)
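The data-to-init rename is cosmetic, but the batching semantics in the diff are easy to misread. A minimal sketch of the three resulting shapes, not from the commit itself (it uses the bp.init alias exported by brainpy/__init__.py above; bm.batching_mode is assumed to be the BatchingMode instance exported by brainpy.math):

import brainpy as bp
import brainpy.math as bm

v1 = bp.init.variable_(bm.zeros, size=(10,))                        # no batching: shape (10,)
v2 = bp.init.variable_(bm.zeros, size=(10,), batch_size_or_mode=4)  # int batch size: shape (4, 10), batch_axis=0
v3 = bp.init.variable_(bm.zeros, size=(10,),
                       batch_size_or_mode=bm.batching_mode)         # batching mode: shape (1, 10)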

brainpy/math/_utils.py

Lines changed: 22 additions & 6 deletions

@@ -1,10 +1,14 @@
 # -*- coding: utf-8 -*-
 
-from typing import Callable
 import functools
-from .ndarray import Array
+from typing import Callable
+
+import jax
+import numpy as np
 from jax.tree_util import tree_map
 
+from .ndarray import Array
+
 
 def wraps(fun: Callable):
   """Specialized version of functools.wraps for wrapping numpy functions.
@@ -15,6 +19,7 @@ def wraps(fun: Callable):
   this reason, it is important that parameter names match those in the original
   numpy function.
   """
+
   def wrap(op):
     docstr = getattr(fun, "__doc__", None)
     op.__doc__ = docstr
@@ -27,20 +32,31 @@ def wrap(op):
       else:
        setattr(op, attr, value)
     return op
+
   return wrap
 
+
 def _as_jax_array(a):
   return a.value if isinstance(a, Array) else a
 
+
+def _as_brainpy_array(a):
+  return Array(a) if isinstance(a, (np.ndarray, jax.Array)) else a
+
+
 def _is_leaf(a):
   return isinstance(a, Array)
 
 
-def _compatible_with_brainpy_array(fun):
+def _compatible_with_brainpy_array(fun: Callable, return_brainpy_array: bool = False):
   @functools.wraps(fun)
   def new_fun(*args, **kwargs):
     args = tree_map(_as_jax_array, args, is_leaf=_is_leaf)
-    kwargs = tree_map(_as_jax_array, kwargs, is_leaf=_is_leaf)
-    return fun(*args, **kwargs)
-  return new_fun
+    if len(kwargs):
+      kwargs = tree_map(_as_jax_array, kwargs, is_leaf=_is_leaf)
+    r = fun(*args, **kwargs)
+    return tree_map(_as_brainpy_array, r) if return_brainpy_array else r
+
+  new_fun.__doc__ = getattr(fun, "__doc__", None)
 
+  return new_fun
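A sketch of how the extended decorator behaves, not a call site from this commit (wrapping jnp.tanh is an invented example; _compatible_with_brainpy_array is an internal helper):

import jax.numpy as jnp
import brainpy.math as bm
from brainpy.math._utils import _compatible_with_brainpy_array

tanh = _compatible_with_brainpy_array(jnp.tanh, return_brainpy_array=True)
x = bm.ones(3)  # a brainpy Array
y = tanh(x)     # args are unwrapped to jax arrays; the result is re-wrapped as a brainpy Array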

brainpy/math/operators/pre_syn_post.py

Lines changed: 7 additions & 8 deletions

@@ -8,7 +8,6 @@
 
 from brainpy.errors import MathError
 from brainpy.math.numpy_ops import as_jax
-from brainpy.types import ArrayType
 
 __all__ = [
   # pre-to-post
@@ -43,10 +42,10 @@ def _raise_pre_ids_is_none(pre_ids):
                     f'(brainpy.math.ndim(pre_values) != 0).')
 
 
-def pre2post_event_sum(events: ArrayType,
-                       pre2post: Tuple[ArrayType, ArrayType],
+def pre2post_event_sum(events,
+                       pre2post,
                        post_num: int,
-                       values: Union[float, ArrayType] = 1.):
+                       values = 1.):
   """The pre-to-post event-driven synaptic summation with `CSR` synapse structure.
 
   When ``values`` is a scalar, this function is equivalent to
@@ -103,11 +102,11 @@ def pre2post_event_sum(events: ArrayType,
                          transpose=True)
 
 
-def pre2post_coo_event_sum(events: ArrayType,
-                           pre_ids: ArrayType,
-                           post_ids: ArrayType,
+def pre2post_coo_event_sum(events,
+                           pre_ids,
+                           post_ids,
                            post_num: int,
-                           values: Union[float, ArrayType] = 1.):
+                           values = 1.):
   """The pre-to-post synaptic computation with event-driven summation.
 
   Parameters
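The loosened annotations do not change the call pattern. A hypothetical call with a hand-written CSR structure (the indices/indptr values are invented; the bm-level export of pre2post_event_sum is assumed from this module's __all__):

import brainpy.math as bm

indices = bm.array([0, 1, 1])     # postsynaptic targets of each synapse (CSR)
indptr = bm.array([0, 2, 3])      # two presynaptic neurons
events = bm.array([True, False])  # only pre-neuron 0 spikes
out = bm.pre2post_event_sum(events, (indices, indptr), post_num=2, values=1.)
# out[i] sums `values` over spiking presynaptic inputs of post-neuron i, here [1., 1.]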

brainpy/math/operators/sparse_matmul.py

Lines changed: 4 additions & 5 deletions

@@ -9,7 +9,6 @@
 
 from brainpy.math.ndarray import Array
 from brainpy.math.numpy_ops import as_jax
-from brainpy.types import ArrayType
 
 __all__ = [
   'sparse_matmul',
@@ -18,10 +17,10 @@
 ]
 
 
-def event_csr_matvec(values: ArrayType,
-                     indices: ArrayType,
-                     indptr: ArrayType,
-                     events: ArrayType,
+def event_csr_matvec(values,
+                     indices,
+                     indptr,
+                     events,
                      shape: Tuple[int, int],
                      transpose: bool = False):
   """The pre-to-post event-driven synaptic summation with `CSR` synapse structure.

brainpy/types.py

Lines changed: 6 additions & 2 deletions

@@ -1,15 +1,17 @@
 # -*- coding: utf-8 -*-
 
-from typing import TypeVar, Tuple
+from typing import TypeVar, Tuple, Union, Callable
 
 import jax.numpy as jnp
 import numpy as np
 
 from brainpy.math.ndarray import Array, Variable, TrainVar
+from brainpy import connect as conn
+from brainpy import initialize as init
 
 __all__ = [
   'ArrayType', 'Parameter', 'PyTree',
-  'Shape',
+  'Shape', 'Initializer',
   'Output', 'Monitor'
 ]
 
@@ -26,4 +28,6 @@
 # component
 Output = TypeVar('Output')  # noqa
 Monitor = TypeVar('Monitor')  # noqa
+Connector = Union[conn.Connector, Array, Variable, jnp.ndarray, np.ndarray]
+Initializer = Union[init.Initializer, Callable, Array, Variable, jnp.ndarray, np.ndarray]
 
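A sketch of how the new Initializer alias can be used in downstream signatures (make_weights is hypothetical; the alias and its Union members come straight from this diff):

from typing import Union
from brainpy.types import Initializer

def make_weights(g_max: Union[float, Initializer] = 1.):
  # accepts a float, an Initializer instance, a plain callable, or an array
  ...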
