Skip to content

Commit e871ed8

Browse files
authored
Merge pull request #124 from PKU-NIP-Lab/changes
fix DOGDecay bugs; add more features
2 parents ef9ab20 + 8210377 commit e871ed8

32 files changed

+1113
-524
lines changed

brainpy/analysis/lowdim/tests/test_phase_plane.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def int_x(x, t, Iext):
1818
analyzer = bp.analysis.PhasePlane1D(model=int_x,
1919
target_vars={'x': [-2, 2]},
2020
pars_update={'Iext': 0.},
21-
resolutions=0.001)
21+
resolutions=0.01)
2222

2323
plt.ion()
2424
analyzer.plot_vector_field()
@@ -60,7 +60,7 @@ def int_s2(s2, t, s1):
6060
analyzer = bp.analysis.PhasePlane2D(
6161
model=[int_s1, int_s2],
6262
target_vars={'s1': [0, 1], 's2': [0, 1]},
63-
resolutions=0.0005
63+
resolutions=0.001
6464
)
6565
plt.ion()
6666
analyzer.plot_vector_field()

brainpy/base/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def vars(self, method='absolute', level=-1, include_self=True):
6969
----------
7070
method : str
7171
The method to access the variables.
72+
level: int
73+
The hierarchy level to find variables.
74+
include_self: bool
75+
Whether include the variables in the self.
7276
7377
Returns
7478
-------
@@ -95,6 +99,10 @@ def train_vars(self, method='absolute', level=-1, include_self=True):
9599
----------
96100
method : str
97101
The method to access the variables. Support 'absolute' and 'relative'.
102+
level: int
103+
The hierarchy level to find TrainVar instances.
104+
include_self: bool
105+
Whether include the TrainVar instances in the self.
98106
99107
Returns
100108
-------
@@ -109,7 +117,7 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p
109117
if _paths is None:
110118
_paths = set()
111119
gather = Collector()
112-
if (level > 0) and (_lid >= level):
120+
if (level > -1) and (_lid >= level):
113121
return gather
114122
if method == 'absolute':
115123
nodes = []
@@ -127,7 +135,7 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p
127135
gather[node.name] = node
128136
nodes.append(node)
129137
for v in nodes:
130-
gather.update(v._find_nodes(method=method, level=-1, _lid=_lid + 1, _paths=_paths,
138+
gather.update(v._find_nodes(method=method, level=level, _lid=_lid + 1, _paths=_paths,
131139
include_self=include_self))
132140
if include_self: gather[self.name] = self
133141

@@ -163,8 +171,10 @@ def nodes(self, method='absolute', level=-1, include_self=True):
163171
----------
164172
method : str
165173
The method to access the nodes.
166-
_paths : set, Optional
167-
The data structure to solve the circular reference.
174+
level: int
175+
The hierarchy level to find nodes.
176+
include_self: bool
177+
Whether include the self.
168178
169179
Returns
170180
-------

brainpy/base/naming.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# -*- coding: utf-8 -*-
22

3-
from brainpy import errors
4-
53
import logging
64

7-
logger = logging.getLogger('brainpy.base.naming')
5+
from brainpy import errors
86

7+
logger = logging.getLogger('brainpy.base.naming')
98

109
__all__ = [
1110
'check_name_uniqueness',
@@ -25,11 +24,13 @@ def check_name_uniqueness(name, obj):
2524
f'Please choose another name.')
2625
if name in _name2id:
2726
if _name2id[name] != id(obj):
28-
raise errors.UniqueNameError(f'In BrainPy, each object should have a unique name. '
29-
f'However, we detect that {obj} has a used name "{name}". \n\n'
30-
f'If you try t0 run multiple trials, you may need '
31-
f'"brainpy.base.clear_name_cache()" '
32-
f'to clear all used names. ')
27+
raise errors.UniqueNameError(
28+
f'In BrainPy, each object should have a unique name. '
29+
f'However, we detect that {obj} has a used name "{name}". \n'
30+
f'If you try to run multiple trials, you may need \n\n'
31+
f'>>> brainpy.base.clear_name_cache() \n\n'
32+
f'to clear all cached names. '
33+
)
3334
else:
3435
_name2id[name] = id(obj)
3536

brainpy/base/tests/test_base.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import unittest
4+
5+
import brainpy as bp
6+
import brainpy.math as bm
7+
8+
9+
class TestCollectionFunction(unittest.TestCase):
10+
def test_f_nodes(self):
11+
class C(bp.dyn.DynamicalSystem):
12+
def __init__(self):
13+
super(C, self).__init__()
14+
15+
class B(bp.dyn.DynamicalSystem):
16+
def __init__(self):
17+
super(B, self).__init__()
18+
19+
self.child1 = C()
20+
self.child2 = C()
21+
22+
class A(bp.dyn.DynamicalSystem):
23+
def __init__(self):
24+
super(A, self).__init__()
25+
26+
self.child1 = B()
27+
self.child2 = B()
28+
29+
net = bp.dyn.Network(a1=A(), a2=A())
30+
print(net.nodes(level=2))
31+
self.assertTrue(len(net.nodes(level=0)) == 0)
32+
self.assertTrue(len(net.nodes(level=0, include_self=False)) == 0)
33+
self.assertTrue(len(net.nodes(level=1)) == (1 + 2))
34+
self.assertTrue(len(net.nodes(level=1, include_self=False)) == 2)
35+
self.assertTrue(len(net.nodes(level=2)) == (1 + 2 + 4))
36+
self.assertTrue(len(net.nodes(level=2, include_self=False)) == (2 + 4))
37+
self.assertTrue(len(net.nodes(level=3)) == (1 + 2 + 4 + 8))
38+
self.assertTrue(len(net.nodes(level=3, include_self=False)) == (2 + 4 + 8))
39+
40+
def test_f_vars(self):
41+
class C(bp.dyn.DynamicalSystem):
42+
def __init__(self):
43+
super(C, self).__init__()
44+
45+
self.var1 = bm.Variable(bm.zeros(1))
46+
self.var2 = bm.Variable(bm.zeros(1))
47+
48+
class B(bp.dyn.DynamicalSystem):
49+
def __init__(self):
50+
super(B, self).__init__()
51+
52+
self.child1 = C()
53+
self.child2 = C()
54+
55+
self.var1 = bm.Variable(bm.zeros(1))
56+
self.var2 = bm.Variable(bm.zeros(1))
57+
58+
class A(bp.dyn.DynamicalSystem):
59+
def __init__(self):
60+
super(A, self).__init__()
61+
62+
self.child1 = B()
63+
self.child2 = B()
64+
65+
self.var1 = bm.Variable(bm.zeros(1))
66+
self.var2 = bm.Variable(bm.zeros(1))
67+
68+
net = bp.dyn.Network(a1=A(), a2=A())
69+
print(net.vars(level=2))
70+
self.assertTrue(len(net.vars(level=0)) == 0)
71+
self.assertTrue(len(net.vars(level=0, include_self=False)) == 0)
72+
self.assertTrue(len(net.vars(level=1)) == 2*2)
73+
self.assertTrue(len(net.vars(level=1, include_self=False)) == 2*2)
74+
self.assertTrue(len(net.vars(level=2)) == (2 + 4) * 2)
75+
self.assertTrue(len(net.vars(level=2, include_self=False)) == (2 + 4) * 2)
76+
self.assertTrue(len(net.vars(level=3)) == (2 + 4 + 8) * 2)
77+
self.assertTrue(len(net.vars(level=3, include_self=False)) == (2 + 4 + 8) * 2)
78+
79+
80+
81+
82+

brainpy/base/tests/test_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
2828
# variables
2929
self.t_last_pre_spike = bp.math.ones(self.size) * -1e7
3030
self.s = bp.math.zeros(self.size)
31-
self.g = self.register_constant_delay('g', size=self.size, delay=delay)
31+
self.g = bp.dyn.ConstantDelay(size=self.size, delay=delay)
3232

3333
@bp.odeint
3434
def int_s(self, s, t, TT):
@@ -227,7 +227,7 @@ def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
227227
# variables
228228
self.t_last_pre_spike = bp.math.Variable(bp.math.ones(self.size) * -1e7)
229229
self.s = bp.math.Variable(bp.math.zeros(self.size))
230-
self.g = self.register_constant_delay('g', size=self.size, delay=delay)
230+
self.g = bp.dyn.ConstantDelay(size=self.size, delay=delay)
231231

232232
@bp.odeint
233233
def int_s(self, s, t, TT):

brainpy/dyn/base.py

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -71,82 +71,6 @@ def ints(self, method='absolute'):
7171
gather[f'{node_path}.{k}' if node_path else k] = v
7272
return gather
7373

74-
def child_ds(self, method='absolute', include_self=False):
75-
"""Return the children instance of dynamical systems.
76-
77-
This is a shortcut function to get all children dynamical system
78-
in this object. For example:
79-
80-
>>> import brainpy as bp
81-
>>>
82-
>>> class Net(bp.DynamicalSystem):
83-
>>> def __init__(self, **kwargs):
84-
>>> super(Net, self).__init__(**kwargs)
85-
>>> self.A = bp.NeuGroup(10)
86-
>>> self.B = bp.NeuGroup(20)
87-
>>>
88-
>>> def update(self, _t, _dt):
89-
>>> for node in self.child_ds().values():
90-
>>> node.update(_t, _dt)
91-
>>>
92-
>>> net = Net()
93-
>>> net.child_ds()
94-
{'NeuGroup0': <brainpy.simulation.brainobjects.neuron.NeuGroup object at 0x000001ABD4FF02B0>,
95-
'NeuGroup1': <brainpy.simulation.brainobjects.neuron.NeuGroup object at 0x000001ABD74E5670>}
96-
97-
Parameters
98-
----------
99-
method : str
100-
The method to access the children nodes.
101-
include_self : bool
102-
Whether include the self dynamical system.
103-
104-
Returns
105-
-------
106-
collector: Collector
107-
A Collector includes all children systems.
108-
"""
109-
nodes = self.nodes(method=method).subset(DynamicalSystem).unique()
110-
if not include_self:
111-
if method == 'absolute':
112-
nodes.pop(self.name)
113-
elif method == 'relative':
114-
nodes.pop('')
115-
else:
116-
raise ValueError(f'Unknown access method: {method}')
117-
return nodes
118-
119-
def register_constant_delay(self, key, size, delay, dtype=None):
120-
"""Register a constant delay, whose update method will be appended into
121-
the ``self.steps`` in this host class.
122-
123-
Parameters
124-
----------
125-
key : str
126-
The delay name.
127-
size : int, list of int, tuple of int
128-
The delay data size.
129-
delay : int, float, ndarray
130-
The delay time, with the unit same with `brainpy.math.get_dt()`.
131-
dtype : optional
132-
The data type.
133-
134-
Returns
135-
-------
136-
delay : ConstantDelay
137-
An instance of ConstantDelay.
138-
"""
139-
if not hasattr(self, 'steps'):
140-
raise ModelBuildError('Please initialize the super class first before '
141-
'registering constant_delay. \n\n'
142-
'super(YourClassName, self).__init__(**kwargs)')
143-
if not key.isidentifier(): raise ValueError(f'{key} is not a valid identifier.')
144-
cdelay = ConstantDelay(size=size,
145-
delay=delay,
146-
name=f'{self.name}_delay_{key}',
147-
dtype=dtype)
148-
return cdelay
149-
15074
def __call__(self, *args, **kwargs):
15175
"""The shortcut to call ``update`` methods."""
15276
return self.update(*args, **kwargs)
@@ -202,7 +126,8 @@ def update(self, _t, _dt):
202126
In this update function, the update functions in children systems are
203127
iteratively called.
204128
"""
205-
for node in self.child_ds().values():
129+
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
130+
for node in nodes.values():
206131
node.update(_t, _dt)
207132

208133
def __getattr__(self, item):

0 commit comments

Comments
 (0)