Skip to content

Commit 10853b8

Browse files
committed
updates
1 parent b3e357f commit 10853b8

File tree

5 files changed

+29
-44
lines changed

5 files changed

+29
-44
lines changed

brainpy/_src/dynsys.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ class Network(DynSysGroup):
411411
pass
412412

413413

414-
class Sequential(DynamicalSystem, AutoDelaySupp):
414+
class Sequential(DynamicalSystem, AutoDelaySupp, Container):
415415
"""A sequential `input-output` module.
416416
417417
Modules will be added to it in the order they are passed in the
@@ -468,22 +468,12 @@ def __init__(
468468
**modules_as_dict
469469
):
470470
super().__init__(name=name, mode=mode)
471-
self._dyn_modules = bm.NodeDict()
472-
self._static_modules = dict()
473-
i = 0
474-
for m in modules_as_tuple + tuple(modules_as_dict.values()):
475-
key = self.__format_key(i)
476-
if isinstance(m, bm.BrainPyObject):
477-
self._dyn_modules[key] = m
478-
else:
479-
self._static_modules[key] = m
480-
i += 1
481-
self._num = i
471+
self.children = bm.node_dict(self.format_elements(object, *modules_as_tuple, **modules_as_dict))
482472

483473
def update(self, x):
484474
"""Update function of a sequential model.
485475
"""
486-
for m in self.__all_nodes():
476+
for m in self.children.values():
487477
x = m(x)
488478
return x
489479

@@ -494,15 +484,6 @@ def return_info(self):
494484
f'not instance of {AutoDelaySupp.__name__}')
495485
return last.return_info()
496486

497-
def append(self, module: Callable):
498-
assert isinstance(module, Callable)
499-
key = self.__format_key(self._num)
500-
if isinstance(module, bm.BrainPyObject):
501-
self._dyn_modules[key] = module
502-
else:
503-
self._static_modules[key] = module
504-
self._num += 1
505-
506487
def __format_key(self, i):
507488
return f'l-{i}'
508489

@@ -518,19 +499,17 @@ def __all_nodes(self):
518499

519500
def __getitem__(self, key: Union[int, slice, str]):
520501
if isinstance(key, str):
521-
if key in self._dyn_modules:
522-
return self._dyn_modules[key]
523-
elif key in self._static_modules:
524-
return self._static_modules[key]
502+
if key in self.children:
503+
return self.children[key]
525504
else:
526505
raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
527506
elif isinstance(key, slice):
528-
return Sequential(*(self.__all_nodes()[key]))
507+
return Sequential(**dict(tuple(self.children.items())[key]))
529508
elif isinstance(key, int):
530-
return self.__all_nodes()[key]
509+
return tuple(self.children.values())[key]
531510
elif isinstance(key, (tuple, list)):
532-
_all_nodes = self.__all_nodes()
533-
return Sequential(*[_all_nodes[k] for k in key])
511+
_all_nodes = tuple(self.children.items())
512+
return Sequential(**dict(_all_nodes[k] for k in key))
534513
else:
535514
raise KeyError(f'Unknown type of key: {type(key)}')
536515

@@ -653,7 +632,7 @@ def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None):
653632
batch_axis_name=bm.sharding.BATCH_AXIS)
654633

655634
def __repr__(self):
656-
return f'{self.__class__.__name__}(name={self.name}, mode={self.mode}, size={self.size})'
635+
return f'{self.name}(mode={self.mode}, size={self.size})'
657636

658637
def __getitem__(self, item):
659638
return DynView(target=self, index=item)

brainpy/_src/math/object_transform/base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -649,8 +649,8 @@ def __init__(self, seq=()):
649649
self.extend(seq)
650650

651651
def append(self, element) -> 'NodeList':
652-
if not isinstance(element, BrainPyObject):
653-
raise TypeError(f'element must be an instance of {BrainPyObject.__name__}.')
652+
# if not isinstance(element, BrainPyObject):
653+
# raise TypeError(f'element must be an instance of {BrainPyObject.__name__}.')
654654
super().append(element)
655655
return self
656656

@@ -668,10 +668,10 @@ class NodeDict(dict):
668668
:py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`.
669669
"""
670670

671-
def _check_elem(self, elem):
672-
if not isinstance(elem, BrainPyObject):
673-
raise TypeError(f'Element should be {BrainPyObject.__name__}, but got {type(elem)}.')
674-
return elem
671+
# def _check_elem(self, elem):
672+
# if not isinstance(elem, BrainPyObject):
673+
# raise TypeError(f'Element should be {BrainPyObject.__name__}, but got {type(elem)}.')
674+
# return elem
675675

676676
def __init__(self, *args, **kwargs):
677677
super().__init__()
@@ -690,7 +690,7 @@ def update(self, *args, **kwargs) -> 'VarDict':
690690
return self
691691

692692
def __setitem__(self, key, value) -> 'VarDict':
693-
super().__setitem__(key, self._check_elem(value))
693+
super().__setitem__(key, value)
694694
return self
695695

696696

brainpy/_src/math/object_transform/naming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def check_name_uniqueness(name, obj):
3232
_name2id[name] = id(obj)
3333

3434

35-
def get_unique_name(type_):
35+
def get_unique_name(type_: str):
3636
"""Get the unique name for the given object type."""
3737
if type_ not in _typed_names:
3838
_typed_names[type_] = 0

brainpy/_src/mixin.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010

1111
from brainpy import math as bm, tools
12+
from brainpy._src.math.object_transform.naming import get_unique_name
1213
from brainpy._src.initialize import parameter
1314
from brainpy.types import ArrayType
1415

@@ -192,19 +193,25 @@ def __repr__(self):
192193
string = ", \n".join(child_str)
193194
return f'{cls_name}({string})'
194195

196+
def __get_elem_name(self, elem):
197+
if isinstance(elem, bm.BrainPyObject):
198+
return elem.name
199+
else:
200+
return get_unique_name('ContainerElem')
201+
195202
def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
196203
res = dict()
197204

198205
# add tuple-typed components
199206
for module in children_as_tuple:
200207
if isinstance(module, child_type):
201-
res[module.name] = module
208+
res[self.__get_elem_name(module)] = module
202209
elif isinstance(module, (list, tuple)):
203210
for m in module:
204211
if not isinstance(m, child_type):
205212
raise ValueError(f'Should be instance of {child_type.__name__}. '
206213
f'But we got {type(m)}')
207-
res[m.name] = m
214+
res[self.__get_elem_name(m)] = m
208215
elif isinstance(module, dict):
209216
for k, v in module.items():
210217
if not isinstance(v, child_type):
@@ -226,12 +233,12 @@ def add_elem(self, **elements):
226233
"""Add new elements.
227234
228235
>>> obj = Container()
229-
>>> obj.add_elem(1.)
236+
>>> obj.add_elem(a=1.)
230237
231238
Args:
232239
elements: children objects.
233240
"""
234-
self.check_hierarchies(type(self), **elements)
241+
# self.check_hierarchies(type(self), **elements)
235242
self.children.update(self.format_elements(object, **elements))
236243

237244

brainpy/_src/runners.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def __repr__(self):
379379
def reset_state(self):
380380
"""Reset state of the ``DSRunner``."""
381381
self.i0 = 0
382-
self.t0 = self.t0
383382

384383
def predict(
385384
self,

0 commit comments

Comments
 (0)