@@ -59,17 +59,16 @@ class DynamicalSystem(BrainPyObject):
5959 The model computation mode. It should be instance of :py:class:`~.Mode`.
6060 """
6161
62- '''Online fitting method.'''
6362 online_fit_by : Optional [OnlineAlgorithm ]
63+ '''Online fitting method.'''
6464
65- '''Offline fitting method.'''
6665 offline_fit_by : Optional [OfflineAlgorithm ]
66+ '''Offline fitting method.'''
6767
68+ global_delay_data : Dict [str , Tuple [Union [bm .LengthDelay , None ], bm .Variable ]] = dict ()
6869 '''Global delay data, which stores the delay variables and corresponding delay targets.
69-
7070 This variable is useful when the same target variable is used in multiple mappings,
7171 as it can reduce the duplicate delay variable registration.'''
72- global_delay_data : Dict [str , Tuple [Union [bm .LengthDelay , None ], bm .Variable ]] = dict ()
7372
7473 def __init__ (
7574 self ,
@@ -435,15 +434,45 @@ def clear_input(self):
435434 node .clear_input ()
436435
437436
438- class Sequential (Container ):
437+ class Sequential (DynamicalSystem ):
439438 def __init__ (
440439 self ,
441440 * modules ,
442441 name : str = None ,
443442 mode : Mode = normal ,
444443 ** kw_modules
445444 ):
446- super (Sequential , self ).__init__ (* modules , name = name , mode = mode , ** kw_modules )
445+ super ().__init__ (name = name , mode = mode )
446+ self ._modules = tuple (modules ) + tuple (kw_modules .values ())
447+
448+ seq_modules = [m for m in modules if isinstance (m , BrainPyObject )]
449+ dict_modules = {k : m for k , m in kw_modules .items () if isinstance (m , BrainPyObject )}
450+
451+ # add tuple-typed components
452+ for module in seq_modules :
453+ if isinstance (module , BrainPyObject ):
454+ self .implicit_nodes [module .name ] = module
455+ elif isinstance (module , (list , tuple )):
456+ for m in module :
457+ if not isinstance (m , BrainPyObject ):
458+ raise ValueError (f'Should be instance of { BrainPyObject .__name__ } . '
459+ f'But we got { type (m )} ' )
460+ self .implicit_nodes [m .name ] = module
461+ elif isinstance (module , dict ):
462+ for k , v in module .items ():
463+ if not isinstance (v , BrainPyObject ):
464+ raise ValueError (f'Should be instance of { BrainPyObject .__name__ } . '
465+ f'But we got { type (v )} ' )
466+ self .implicit_nodes [k ] = v
467+ else :
468+ raise ValueError (f'Cannot parse sub-systems. They should be { BrainPyObject .__name__ } '
469+ f'or a list/tuple/dict of { BrainPyObject .__name__ } .' )
470+ # add dict-typed components
471+ for k , v in dict_modules .items ():
472+ if not isinstance (v , BrainPyObject ):
473+ raise ValueError (f'Should be instance of { BrainPyObject .__name__ } . '
474+ f'But we got { type (v )} ' )
475+ self .implicit_nodes [k ] = v
447476
448477 def __getattr__ (self , item ):
449478 """Wrap the dot access ('self.'). """
@@ -463,7 +492,7 @@ def __getitem__(self, key: Union[int, slice]):
463492 components = tuple (self .implicit_nodes .values ())[key ]
464493 return Sequential (dict (zip (keys , components )))
465494 elif isinstance (key , int ):
466- return self .implicit_nodes .values ()[key ]
495+ return tuple ( self .implicit_nodes .values () )[key ]
467496 elif isinstance (key , (tuple , list )):
468497 all_keys = tuple (self .implicit_nodes .keys ())
469498 all_vals = tuple (self .implicit_nodes .values ())
@@ -478,27 +507,7 @@ def __getitem__(self, key: Union[int, slice]):
478507 raise KeyError (f'Unknown type of key: { type (key )} ' )
479508
480509 def __repr__ (self ):
481- def f (x ):
482- if not isinstance (x , DynamicalSystem ) and callable (x ):
483- signature = inspect .signature (x )
484- args = [f'{ k } ={ v .default } ' for k , v in signature .parameters .items ()
485- if v .default is not inspect .Parameter .empty ]
486- args = ', ' .join (args )
487- while not hasattr (x , '__name__' ):
488- if not hasattr (x , 'func' ):
489- break
490- x = x .func # Handle functools.partial
491- if not hasattr (x , '__name__' ) and hasattr (x , '__class__' ):
492- return x .__class__ .__name__
493- if args :
494- return f'{ x .__name__ } (*, { args } )'
495- return x .__name__
496- else :
497- x = repr (x ).split ('\n ' )
498- x = [x [0 ]] + [' ' + y for y in x [1 :]]
499- return '\n ' .join (x )
500-
501- entries = '\n ' .join (f' [{ i } ] { f (x )} ' for i , x in enumerate (self ))
510+ entries = '\n ' .join (f' [{ i } ] { tools .repr_object (x )} ' for i , x in enumerate (self ._modules ))
502511 return f'{ self .__class__ .__name__ } (\n { entries } \n )'
503512
504513 def update (self , sha : dict , x : Any ) -> Array :
@@ -516,8 +525,11 @@ def update(self, sha: dict, x: Any) -> Array:
516525 y: Array
517526 The output tensor.
518527 """
519- for node in self .implicit_nodes .values ():
520- x = node (sha , x )
528+ for m in self ._modules :
529+ if isinstance (m , DynamicalSystem ):
530+ x = m (sha , x )
531+ else :
532+ x = m (x )
521533 return x
522534
523535
0 commit comments