
Commit 81f4720

Merge pull request #337 from chaoming0625/master

Recent updates

2 parents 3184caf + 2776056


51 files changed: +2128 −590 lines

.github/workflows/Sync_branches.yml

Lines changed: 0 additions & 18 deletions
This file was deleted.

brainpy/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -35,7 +35,7 @@
 # convenient alias
 conn = connect
 init = initialize
-optimizers = optim
+globals()['optimizers'] = optim
 
 # numerical integrators
 from brainpy import integrators
@@ -58,8 +58,11 @@
   synapses,  # synaptic dynamics
   synouts,  # synaptic output
   synplast,  # synaptic plasticity
+  experimental,  # experimental model
 )
+from brainpy._src.dyn.base import not_pass_shargs
 from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
+                                   Module as Module,
                                    Container as Container,
                                    Sequential as Sequential,
                                    Network as Network,
@@ -71,7 +74,6 @@
                                    TwoEndConn as TwoEndConn,
                                    CondNeuGroup as CondNeuGroup,
                                    Channel as Channel)
-from brainpy._src.dyn.base import (DSPartial as DSPartial)
 from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg,  # transformations
                                         LoopOverTime as LoopOverTime,)
 from brainpy._src.dyn.runners import (DSRunner as DSRunner)  # runner
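
Assigning through ``globals()`` keeps the ``brainpy.optimizers`` alias working at runtime while hiding the name from static analysis, presumably to steer users toward ``brainpy.optim``. A quick check of the runtime equivalence::

  import brainpy as bp

  # the alias still resolves to the same module object at runtime
  assert bp.optimizers is bp.optim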

brainpy/_src/checkpoints/serialization.py

Lines changed: 2 additions & 1 deletion
@@ -1264,7 +1264,8 @@ def save_pytree(
 
   if os.path.splitext(filename)[-1] != '.bp':
     filename = filename + '.bp'
-  os.makedirs(os.path.dirname(filename), exist_ok=True)
+  if os.path.dirname(filename):
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
   if not overwrite and os.path.exists(filename):
     raise InvalidCheckpointPath(filename)
   target = to_bytes(target)
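
This guard matters when ``filename`` has no directory component: ``os.path.dirname('model.bp')`` returns an empty string, and ``os.makedirs('')`` raises ``FileNotFoundError``. A minimal sketch of the pattern the patch adopts::

  import os

  filename = 'model.bp'                # bare filename, no directory part
  dirname = os.path.dirname(filename)  # -> ''

  # os.makedirs('') raises FileNotFoundError, so only create the
  # directory when there is actually one to create
  if dirname:
    os.makedirs(dirname, exist_ok=True)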

brainpy/_src/dyn/base.py

Lines changed: 68 additions & 64 deletions
@@ -2,28 +2,28 @@
 
 import collections
 import gc
-from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
+import warnings
+from typing import Union, Dict, Callable, Sequence, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 
-from brainpy import tools, check
+from brainpy import tools
 from brainpy._src import math as bm
-from brainpy._src.math.ndarray import Variable, VariableView
-from brainpy._src.math.object_transform.base import BrainPyObject, Collector
 from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
 from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
 from brainpy._src.integrators import odeint, sdeint
-from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
+from brainpy._src.math.ndarray import Variable, VariableView
+from brainpy._src.math.object_transform.base import BrainPyObject, Collector
 from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 
 __all__ = [
   # general class
   'DynamicalSystem',
+  'Module',
   'FuncAsDynSys',
-  'DSPartial',
 
   # containers
   'Container', 'Network', 'Sequential', 'System',
@@ -48,6 +48,46 @@
 SLICE_VARS = 'slice_vars'
 
 
+def not_pass_shargs(func: Callable):
+  """Label the update function as one that does not receive shared arguments.
+
+  The original update function explicitly requires shared arguments as its first argument::
+
+    class TheModel(DynamicalSystem):
+      def update(self, s, x):
+        # s is the shared arguments, like `t`, `dt`, etc.
+        pass
+
+  So, each time we call the model, we should provide the shared arguments::
+
+    TheModel()(shared, inputs)
+
+  When we label the update function with ``not_pass_shargs``, there is no
+  need to call the dynamical system with shared arguments::
+
+    class NewModel(DynamicalSystem):
+      @not_pass_shargs
+      def update(self, x):
+        pass
+
+    NewModel()(inputs)
+
+  .. versionadded:: 2.3.5
+
+  Parameters
+  ----------
+  func: Callable
+    The function in the :py:class:`~.DynamicalSystem`.
+
+  Returns
+  -------
+  func: Callable
+    The wrapped function for the class.
+  """
+  func._new_style = True
+  return func
+
+
 class DynamicalSystem(BrainPyObject):
   """Base Dynamical System class.
 
@@ -65,7 +105,6 @@ class DynamicalSystem(BrainPyObject):
   we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
   :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
 
-
   Parameters
   ----------
   name : optional, str
@@ -74,12 +113,6 @@
     The model computation mode. It should be instance of :py:class:`~.Mode`.
   """
 
-  online_fit_by: Optional[OnlineAlgorithm]
-  '''Online fitting method.'''
-
-  offline_fit_by: Optional[OfflineAlgorithm]
-  '''Offline fitting method.'''
-
   global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict()
   '''Global delay data, which stores the delay variables and corresponding delay targets.
   This variable is useful when the same target variable is used in multiple mappings,
@@ -97,15 +130,11 @@ def __init__(
                        f'but we got {type(mode)}: {mode}')
     self._mode = mode
 
-    super(DynamicalSystem, self).__init__(name=name)
-
     # local delay variables
     self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
 
-    # fitting parameters
-    self.online_fit_by = None
-    self.offline_fit_by = None
-    self.fit_record = dict()
+    # super initialization
+    super(DynamicalSystem, self).__init__(name=name)
 
   @property
   def mode(self) -> bm.Mode:
@@ -124,7 +153,21 @@ def __repr__(self):
 
   def __call__(self, *args, **kwargs):
     """The shortcut to call ``update`` methods."""
-    return self.update(*args, **kwargs)
+    if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
+      if len(args) and isinstance(args[0], dict):
+        bm.share.save_shargs(**args[0])
+        return self.update(*args[1:], **kwargs)
+      else:
+        return self.update(*args, **kwargs)
+    else:
+      if len(args) and isinstance(args[0], dict):
+        return self.update(*args, **kwargs)
+      else:
+        # If the first argument is not a dict of shared arguments,
+        # get the shared arguments from the global context. Users
+        # should set and update shared arguments in the global
+        # context when calling the model this way.
+        return self.update(bm.share.get_shargs(), *args, **kwargs)
 
   def register_delay(
     self,
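
In effect, ``__call__`` now dispatches on whether ``update`` was labeled with ``not_pass_shargs`` and on whether the first positional argument is a dict of shared arguments. A minimal sketch of the two calling styles, assuming ``bm.share.save_shargs``/``bm.share.get_shargs`` behave as used above (the ``Scale`` model is a hypothetical toy)::

  import brainpy as bp
  import brainpy.math as bm

  class Scale(bp.DynamicalSystem):
    @bp.not_pass_shargs
    def update(self, x):
      # new-style update: shared arguments live in the global context
      return 2. * x

  layer = Scale()
  x = bm.ones(3)

  layer(x)                      # called with the input only
  layer(dict(t=0., dt=0.1), x)  # a leading dict is saved to the global
                                # context, then stripped before update(x)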
@@ -339,26 +382,13 @@ def __del__(self):
         del self.__dict__[key]
     gc.collect()
 
-  @tools.not_customized
-  def online_init(self):
-    raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.')
-
-  @tools.not_customized
-  def online_fit(self,
-                 target: ArrayType,
-                 fit_record: Dict[str, ArrayType]):
-    raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.')
-
-  @tools.not_customized
-  def offline_fit(self,
-                  target: ArrayType,
-                  fit_record: Dict[str, ArrayType]):
-    raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')
-
   def clear_input(self):
     pass
 
 
+Module = DynamicalSystem
+
+
 class FuncAsDynSys(DynamicalSystem):
   """Transform a Python function as a :py:class:`~.DynamicalSystem`
@@ -411,31 +441,6 @@ def __repr__(self):
             f'{indent}num_of_vars={len(self.implicit_vars)})')
 
 
-class DSPartial(FuncAsDynSys):
-  def __init__(
-      self,
-      target: Callable,
-      *args,
-      child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
-      dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
-      shared: Dict = None,
-      **keywords
-  ):
-    super().__init__(target=target, child_objs=child_objs, dyn_vars=dyn_vars)
-
-    check.is_dict_data(shared, all_none=True)
-    self.target = check.is_callable(target, )
-    self.args = tuple(args)
-    self.keywords = keywords
-    self.shared = dict() if shared is None else shared
-
-  def __call__(self, s, *args, **keywords):
-    assert isinstance(s, dict)
-    s = tools.DotDict(s).update(self.shared)
-    args = self.args + (s,) + args
-    keywords = {**self.keywords, **keywords}
-    return self.target(*args, **keywords)
-
 
 class Container(DynamicalSystem):
   """Container object which is designed to add other instances of DynamicalSystem.
@@ -639,7 +644,7 @@ def __repr__(self):
     entries = '\n'.join(f'  [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules))
     return f'{self.__class__.__name__}(\n{entries}\n)'
 
-  def update(self, *args) -> ArrayType:
+  def update(self, s, x) -> ArrayType:
     """Update function of a sequential model.
 
     Parameters
@@ -654,7 +659,6 @@ def update(self, *args) -> ArrayType:
     y: ArrayType
       The output tensor.
     """
-    s, x = (dict(), args[0]) if len(args) == 1 else (args[0], args[1])
     for m in self._modules:
       if isinstance(m, DynamicalSystem):
         x = m(s, x)
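
With the one-argument fallback removed, ``Sequential`` is always driven with an explicit shared-argument dict. A hedged usage sketch (the layer sizes are illustrative)::

  import brainpy as bp
  import brainpy.math as bm

  model = bp.Sequential(
    bp.layers.Dense(10, 20),
    bp.layers.Dense(20, 5),
  )

  x = bm.ones((4, 10))
  y = model(dict(t=0., dt=0.1), x)  # shared dict first, then the input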
@@ -818,7 +822,7 @@ def get_batch_shape(self, batch_size=None):
     else:
       return (batch_size,) + self.varshape
 
-  def update(self, tdi, x=None):
+  def update(self, *args):
     """The function to specify the updating rule.
 
     Parameters

brainpy/_src/dyn/layers/base.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from typing import Optional
 
 import brainpy.math as bm
-from brainpy._src.dyn.base import DynamicalSystem
+from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs
 
 __all__ = [
   'Layer'

brainpy/_src/dyn/layers/conv.py

Lines changed: 5 additions & 4 deletions
@@ -5,6 +5,7 @@
 from jax import lax
 
 from brainpy import math as bm, tools, check
+from brainpy._src.dyn.base import not_pass_shargs
 from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
 from brainpy.types import ArrayType
 from .base import Layer
@@ -153,8 +154,8 @@ def _check_input_dim(self, x):
       raise ValueError(f"input channels={x.shape[-1]} needs to have "
                        f"the same size as in_channels={self.in_channels}.")
 
-  def update(self, *args):
-    x = args[0] if len(args) == 1 else args[1]
+  @not_pass_shargs
+  def update(self, x):
     self._check_input_dim(x)
     w = self.w.value
     if self.mask is not None:
@@ -525,8 +526,8 @@ def __init__(
   def _check_input_dim(self, x):
     raise NotImplementedError
 
-  def update(self, *args):
-    x = args[0] if len(args) == 1 else args[1]
+  @not_pass_shargs
+  def update(self, x):
     self._check_input_dim(x)
 
     w = self.w.value
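
Because their ``update`` is now labeled with ``not_pass_shargs``, convolution layers can be called on the input alone. A hedged sketch (channel counts and shapes are illustrative)::

  import brainpy as bp
  import brainpy.math as bm

  conv = bp.layers.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3))
  img = bm.ones((1, 32, 32, 3))  # (batch, height, width, channels)

  y = conv(img)  # no shared-argument dict required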

brainpy/_src/dyn/layers/dropout.py

Lines changed: 3 additions & 4 deletions
@@ -1,10 +1,9 @@
 # -*- coding: utf-8 -*-
 
-import jax.numpy as jnp
-
 
 from brainpy import math as bm, check
 from .base import Layer
+from brainpy._src.dyn.base import not_pass_shargs
 
 __all__ = [
   'Dropout'
@@ -49,8 +48,8 @@ def __init__(
     self.prob = check.is_float(prob, min_bound=0., max_bound=1.)
     self.rng = bm.random.default_rng(seed)
 
-  def update(self, sha, x):
-    if sha.get('fit', True):
+  def update(self, s, x):
+    if s['fit']:
       keep_mask = self.rng.bernoulli(self.prob, x.shape)
       return bm.where(bm.as_jax(keep_mask), x / self.prob, 0.)
     else:
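
Note the behavioral change here: ``sha.get('fit', True)`` tolerated a missing ``fit`` key, whereas ``s['fit']`` requires the caller to supply it. A minimal sketch of providing the flag explicitly::

  import brainpy as bp
  import brainpy.math as bm

  drop = bp.layers.Dropout(prob=0.8)  # keep probability of 0.8
  x = bm.ones(10)

  y_train = drop(dict(fit=True), x)   # training: random mask, scaled by 1/prob
  y_eval = drop(dict(fit=False), x)   # evaluation: input passes through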
