Skip to content

Commit 454969b

Browse files
authored
Merge pull request #344 from chaoming0625/master
The update and fix of functions and models
2 parents 0a519c0 + d2bd305 commit 454969b

File tree

19 files changed

+801
-185
lines changed

19 files changed

+801
-185
lines changed

brainpy/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.3.6"
3+
__version__ = "2.3.7"
44

55

66
# fundamental supporting modules
@@ -61,20 +61,23 @@
6161
experimental,
6262
)
6363
from brainpy._src.dyn.base import not_pass_shared
64-
from brainpy._src.dyn.base import (DynamicalSystem,
65-
DynamicalSystemNS,
64+
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
6665
Container as Container,
6766
Sequential as Sequential,
6867
Network as Network,
6968
NeuGroup as NeuGroup,
70-
NeuGroupNS as NeuGroupNS,
7169
SynConn as SynConn,
7270
SynOut as SynOut,
7371
SynSTP as SynSTP,
7472
SynLTP as SynLTP,
7573
TwoEndConn as TwoEndConn,
7674
CondNeuGroup as CondNeuGroup,
7775
Channel as Channel)
76+
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
77+
NeuGroupNS as NeuGroupNS)
78+
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
79+
SynSTPNS as SynSTPNS,
80+
SynConnNS as SynConnNS, )
7881
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
7982
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
8083
from brainpy._src.dyn.context import share, Delay

brainpy/_src/dyn/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,20 @@ def __del__(self):
400400
def clear_input(self):
401401
pass
402402

403+
def __rrshift__(self, other):
404+
"""Support using right shift operator to call modules.
405+
406+
Examples
407+
--------
408+
409+
>>> import brainpy as bp
410+
>>> x = bp.math.random.rand((10, 10))
411+
>>> l = bp.layers.Activation('tanh')
412+
>>> y = x >> l
413+
414+
"""
415+
return self.__call__(other)
416+
403417

404418
class DynamicalSystemNS(DynamicalSystem):
405419
"""Dynamical system without the need of shared parameters passing into ``update()`` function."""

brainpy/_src/dyn/layers/normalization.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,18 @@ def update(self, x):
130130
x = bm.as_jax(x)
131131

132132
if share.load('fit'):
133-
mean = jnp.mean(x, self.axis)
134-
mean_of_square = jnp.mean(_square(x), self.axis)
135-
if self.axis_name is not None:
136-
mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]),
137-
axis_name=self.axis_name,
138-
axis_index_groups=self.axis_index_groups),
139-
2)
140-
var = jnp.maximum(0., mean_of_square - _square(mean))
141-
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
142-
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
133+
mean = jnp.mean(x, self.axis)
134+
mean_of_square = jnp.mean(_square(x), self.axis)
135+
if self.axis_name is not None:
136+
mean, mean_of_square = jnp.split(
137+
lax.pmean(jnp.concatenate([mean, mean_of_square]),
138+
axis_name=self.axis_name,
139+
axis_index_groups=self.axis_index_groups),
140+
2
141+
)
142+
var = jnp.maximum(0., mean_of_square - _square(mean))
143+
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
144+
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
143145
else:
144146
mean = self.running_mean.value
145147
var = self.running_var.value
@@ -488,7 +490,7 @@ def __init__(
488490
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
489491
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))
490492

491-
def update(self,x):
493+
def update(self, x):
492494
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
493495
raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), '
494496
f'but we got {x.shape}')
@@ -629,6 +631,8 @@ def __init__(
629631
scale_initializer=scale_initializer,
630632
mode=mode,
631633
name=name)
634+
635+
632636
BatchNorm1D = BatchNorm1d
633637
BatchNorm2D = BatchNorm2d
634-
BatchNorm3D = BatchNorm3d
638+
BatchNorm3D = BatchNorm3d

brainpy/_src/dyn/neurons/biological_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from brainpy import check
77
from brainpy._src.dyn.base import NeuGroupNS
88
from brainpy._src.dyn.context import share
9-
from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_
9+
from brainpy._src.initialize import (OneInit,
10+
Uniform,
11+
Initializer,
12+
parameter,
13+
noise as init_noise,
14+
variable_)
1015
from brainpy._src.integrators.joint_eq import JointEq
1116
from brainpy._src.integrators.ode.generic import odeint
1217
from brainpy._src.integrators.sde.generic import sdeint

brainpy/_src/dyn/neurons/input_groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax.numpy as jnp
66
from brainpy._src.dyn.context import share
77
import brainpy.math as bm
8-
from brainpy._src.dyn.base import NeuGroupNS, not_pass_shared
8+
from brainpy._src.dyn.base import NeuGroupNS
99
from brainpy._src.initialize import Initializer, parameter, variable_
1010
from brainpy.types import Shape, ArrayType
1111

brainpy/_src/dyn/neurons/noise_groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax.numpy as jnp
66
from brainpy._src.dyn.context import share
77
from brainpy import math as bm, initialize as init
8-
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared
8+
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup
99
from brainpy._src.initialize import Initializer
1010
from brainpy._src.integrators.sde.generic import sdeint
1111
from brainpy.types import ArrayType, Shape

0 commit comments

Comments
 (0)