Skip to content

Commit 239088f

Browse files
committed
upgrade object_transform
1 parent 7d88dac commit 239088f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+447
-230
lines changed

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.2.4.1"
3+
__version__ = "2.3.0"
44

55

66
# fundamental modules

brainpy/analysis/highdim/tests/test_slow_points.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def step(s):
105105

106106
finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
107107
finder.find_fps_with_opt_solver(bm.random.random((100, 2)))
108+
bm.clear_buffer_memory()
108109

109110
def test_opt_solver_for_ds1(self):
110111
hh = HH(1)
@@ -117,6 +118,7 @@ def test_opt_solver_for_ds1(self):
117118
'm': bm.random.random((100, 1)),
118119
'h': bm.random.random((100, 1)),
119120
'n': bm.random.random((100, 1))})
121+
bm.clear_buffer_memory()
120122

121123
def test_gd_method_for_func1(self):
122124
gamma = 0.641 # Saturation factor for gating variable
@@ -149,6 +151,7 @@ def step(s):
149151

150152
finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
151153
finder.find_fps_with_gd_method(bm.random.random((100, 2)), num_opt=100)
154+
bm.clear_buffer_memory()
152155

153156
def test_gd_method_for_func2(self):
154157
hh = HH(1)
@@ -162,4 +165,5 @@ def test_gd_method_for_func2(self):
162165
'h': bm.random.random((100, 1)),
163166
'n': bm.random.random((100, 1))},
164167
num_opt=100)
168+
bm.clear_buffer_memory()
165169

brainpy/base/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
and its associated helper class ``Collector`` and ``ArrayCollector``.
88
- For each instance of "BrainPyObject" class, users can retrieve all
99
the variables (or trainable variables), integrators, and nodes.
10-
- This module also provides a ``FunAsAObject`` class to wrap user-defined
10+
- This module also provides a ``FunAsObject`` class to wrap user-defined
1111
functions. In each function, maybe several nodes are used, and
12-
users can initialize a ``FunAsAObject`` by providing the nodes used
13-
in the function. Unfortunately, ``FunAsAObject`` class does not have
12+
users can initialize a ``FunAsObject`` by providing the nodes used
13+
in the function. Unfortunately, ``FunAsObject`` class does not have
1414
the ability to gather nodes automatically.
1515
- This module provides ``io`` helper functions to help users save/load
1616
model states, or share user's customized model with others.

brainpy/base/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ class BrainPyObject(object):
2424
2525
- ``DynamicalSystem`` in *brainpy.dyn.base.py*
2626
- ``Integrator`` in *brainpy.integrators.base.py*
27-
- ``FunAsAObject`` in *brainpy.base.function.py*
27+
- ``FunAsObject`` in *brainpy.base.function.py*
2828
- ``Optimizer`` in *brainpy.optimizers.py*
2929
- ``Scheduler`` in *brainpy.optimizers.py*
30-
30+
- and others.
3131
"""
3232

3333
_excluded_vars = ()

brainpy/base/collector.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,12 @@ def dict(self):
186186
def data(self):
187187
"""Get all data in each value."""
188188
return [x.value for x in self.values()]
189+
190+
@classmethod
191+
def from_other(cls, other: Union[Sequence, Dict]):
192+
if isinstance(other, (tuple, list)):
193+
return cls({id(o): o for o in other})
194+
elif isinstance(other, dict):
195+
return cls(other)
196+
else:
197+
raise TypeError

brainpy/base/function.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# -*- coding: utf-8 -*-
22

33
from typing import Optional, Callable
4+
45
from brainpy import errors
5-
from brainpy.base.base import BrainPyObject
66
from brainpy.base import collector
7+
from brainpy.base.base import BrainPyObject
78

89
math = None
910

1011
__all__ = [
11-
'FunAsAObject',
12+
'FunAsObject',
1213
]
1314

1415

@@ -26,7 +27,7 @@ def _check_var(var):
2627
f'{math.ndarray.__name__}, but we got {type(var)}.')
2728

2829

29-
class FunAsAObject(BrainPyObject):
30+
class FunAsObject(BrainPyObject):
3031
"""The wrapper for Python functions.
3132
3233
Parameters
@@ -46,8 +47,8 @@ def __init__(self, f: Optional[Callable], child_objs=None, dyn_vars=None, name=N
4647
# ---
4748
self._f = f
4849
if name is None:
49-
name = self.unique_name(type_=f.__name__ if hasattr(f, '__name__') else 'FunAsAObject')
50-
super(FunAsAObject, self).__init__(name=name)
50+
name = self.unique_name(type_=f.__name__ if hasattr(f, '__name__') else 'FunAsObject')
51+
super(FunAsObject, self).__init__(name=name)
5152

5253
# nodes
5354
# ---
@@ -89,3 +90,19 @@ def __init__(self, f: Optional[Callable], child_objs=None, dyn_vars=None, name=N
8990

9091
def __call__(self, *args, **kwargs):
9192
return self._f(*args, **kwargs)
93+
94+
def __repr__(self):
95+
name = self.__class__.__name__
96+
# indent = ' ' * (len(name) + 10)
97+
# child_nodes = ['\n'.join([('' if i == 0 else indent) + l for i, l in enumerate(repr(node).split('\n'))])
98+
# for node in self.implicit_nodes.values()]
99+
# first_line = f'{name}(objects=['
100+
# format_res = (
101+
# first_line +
102+
# (',\n' + ' ' * (len(name) + 10)).join(child_nodes) +
103+
# '],\n'
104+
# + (" " * (len(name) + 1)) + f'number of variables = {len(self.implicit_vars)})'
105+
# )
106+
format_ref = (f'{name}(nodes=[{", ".join([n.name for n in tuple(self.implicit_nodes.values())])}],\n' +
107+
" " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')
108+
return format_ref

brainpy/dyn/base.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,16 @@ class DynamicalSystem(BrainPyObject):
5959
The model computation mode. It should be instance of :py:class:`~.Mode`.
6060
"""
6161

62-
'''Online fitting method.'''
6362
online_fit_by: Optional[OnlineAlgorithm]
63+
'''Online fitting method.'''
6464

65-
'''Offline fitting method.'''
6665
offline_fit_by: Optional[OfflineAlgorithm]
66+
'''Offline fitting method.'''
6767

68+
global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], bm.Variable]] = dict()
6869
'''Global delay data, which stores the delay variables and corresponding delay targets.
69-
7070
This variable is useful when the same target variable is used in multiple mappings,
7171
as it can reduce the duplicate delay variable registration.'''
72-
global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], bm.Variable]] = dict()
7372

7473
def __init__(
7574
self,
@@ -435,15 +434,45 @@ def clear_input(self):
435434
node.clear_input()
436435

437436

438-
class Sequential(Container):
437+
class Sequential(DynamicalSystem):
439438
def __init__(
440439
self,
441440
*modules,
442441
name: str = None,
443442
mode: Mode = normal,
444443
**kw_modules
445444
):
446-
super(Sequential, self).__init__(*modules, name=name, mode=mode, **kw_modules)
445+
super().__init__(name=name, mode=mode)
446+
self._modules = tuple(modules) + tuple(kw_modules.values())
447+
448+
seq_modules = [m for m in modules if isinstance(m, BrainPyObject)]
449+
dict_modules = {k: m for k, m in kw_modules.items() if isinstance(m, BrainPyObject)}
450+
451+
# add tuple-typed components
452+
for module in seq_modules:
453+
if isinstance(module, BrainPyObject):
454+
self.implicit_nodes[module.name] = module
455+
elif isinstance(module, (list, tuple)):
456+
for m in module:
457+
if not isinstance(m, BrainPyObject):
458+
raise ValueError(f'Should be instance of {BrainPyObject.__name__}. '
459+
f'But we got {type(m)}')
460+
self.implicit_nodes[m.name] = module
461+
elif isinstance(module, dict):
462+
for k, v in module.items():
463+
if not isinstance(v, BrainPyObject):
464+
raise ValueError(f'Should be instance of {BrainPyObject.__name__}. '
465+
f'But we got {type(v)}')
466+
self.implicit_nodes[k] = v
467+
else:
468+
raise ValueError(f'Cannot parse sub-systems. They should be {BrainPyObject.__name__} '
469+
f'or a list/tuple/dict of {BrainPyObject.__name__}.')
470+
# add dict-typed components
471+
for k, v in dict_modules.items():
472+
if not isinstance(v, BrainPyObject):
473+
raise ValueError(f'Should be instance of {BrainPyObject.__name__}. '
474+
f'But we got {type(v)}')
475+
self.implicit_nodes[k] = v
447476

448477
def __getattr__(self, item):
449478
"""Wrap the dot access ('self.'). """
@@ -463,7 +492,7 @@ def __getitem__(self, key: Union[int, slice]):
463492
components = tuple(self.implicit_nodes.values())[key]
464493
return Sequential(dict(zip(keys, components)))
465494
elif isinstance(key, int):
466-
return self.implicit_nodes.values()[key]
495+
return tuple(self.implicit_nodes.values())[key]
467496
elif isinstance(key, (tuple, list)):
468497
all_keys = tuple(self.implicit_nodes.keys())
469498
all_vals = tuple(self.implicit_nodes.values())
@@ -478,27 +507,7 @@ def __getitem__(self, key: Union[int, slice]):
478507
raise KeyError(f'Unknown type of key: {type(key)}')
479508

480509
def __repr__(self):
481-
def f(x):
482-
if not isinstance(x, DynamicalSystem) and callable(x):
483-
signature = inspect.signature(x)
484-
args = [f'{k}={v.default}' for k, v in signature.parameters.items()
485-
if v.default is not inspect.Parameter.empty]
486-
args = ', '.join(args)
487-
while not hasattr(x, '__name__'):
488-
if not hasattr(x, 'func'):
489-
break
490-
x = x.func # Handle functools.partial
491-
if not hasattr(x, '__name__') and hasattr(x, '__class__'):
492-
return x.__class__.__name__
493-
if args:
494-
return f'{x.__name__}(*, {args})'
495-
return x.__name__
496-
else:
497-
x = repr(x).split('\n')
498-
x = [x[0]] + [' ' + y for y in x[1:]]
499-
return '\n'.join(x)
500-
501-
entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self))
510+
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules))
502511
return f'{self.__class__.__name__}(\n{entries}\n)'
503512

504513
def update(self, sha: dict, x: Any) -> Array:
@@ -516,8 +525,11 @@ def update(self, sha: dict, x: Any) -> Array:
516525
y: Array
517526
The output tensor.
518527
"""
519-
for node in self.implicit_nodes.values():
520-
x = node(sha, x)
528+
for m in self._modules:
529+
if isinstance(m, DynamicalSystem):
530+
x = m(sha, x)
531+
else:
532+
x = m(x)
521533
return x
522534

523535

brainpy/dyn/synapses/abstract_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class Delta(TwoEndConn):
5454
>>>
5555
>>> neu1 = neurons.LIF(1)
5656
>>> neu2 = neurons.LIF(1)
57-
>>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), weights=5.)
57+
>>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), g_max=5.)
5858
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
5959
>>>
6060
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike'])

brainpy/dyn/synapses/biological_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ class BioNMDA(TwoEndConn):
412412
>>>
413413
>>> neu1 = neurons.HH(1)
414414
>>> neu2 = neurons.HH(1)
415-
>>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All(), E=0.)
415+
>>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All())
416416
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
417417
>>>
418418
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])

brainpy/dyn/synapses/learning_rules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ class STP(TwoEndConn):
7676
>>> import brainpy as bp
7777
>>> import matplotlib.pyplot as plt
7878
>>>
79-
>>> neu1 = bp.dyn.LIF(1)
80-
>>> neu2 = bp.dyn.LIF(1)
81-
>>> syn1 = bp.dyn.STP(neu1, neu2, bp.connect.All2All(), U=0.2, tau_d=150., tau_f=2.)
79+
>>> neu1 = bp.neurons.LIF(1)
80+
>>> neu2 = bp.neurons.LIF(1)
81+
>>> syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.2, tau_d=150., tau_f=2.)
8282
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
8383
>>>
8484
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x'])

0 commit comments

Comments
 (0)