@@ -411,7 +411,7 @@ class Network(DynSysGroup):
411411 pass
412412
413413
414- class Sequential (DynamicalSystem , AutoDelaySupp ):
414+ class Sequential (DynamicalSystem , AutoDelaySupp , Container ):
415415 """A sequential `input-output` module.
416416
417417 Modules will be added to it in the order they are passed in the
@@ -468,22 +468,12 @@ def __init__(
468468 ** modules_as_dict
469469 ):
470470 super ().__init__ (name = name , mode = mode )
471- self ._dyn_modules = bm .NodeDict ()
472- self ._static_modules = dict ()
473- i = 0
474- for m in modules_as_tuple + tuple (modules_as_dict .values ()):
475- key = self .__format_key (i )
476- if isinstance (m , bm .BrainPyObject ):
477- self ._dyn_modules [key ] = m
478- else :
479- self ._static_modules [key ] = m
480- i += 1
481- self ._num = i
471+ self .children = bm .node_dict (self .format_elements (object , * modules_as_tuple , ** modules_as_dict ))
482472
483473 def update (self , x ):
484474 """Update function of a sequential model.
485475 """
486- for m in self .__all_nodes ():
476+ for m in self .children . values ():
487477 x = m (x )
488478 return x
489479
@@ -494,15 +484,6 @@ def return_info(self):
494484 f'not instance of { AutoDelaySupp .__name__ } ' )
495485 return last .return_info ()
496486
497- def append (self , module : Callable ):
498- assert isinstance (module , Callable )
499- key = self .__format_key (self ._num )
500- if isinstance (module , bm .BrainPyObject ):
501- self ._dyn_modules [key ] = module
502- else :
503- self ._static_modules [key ] = module
504- self ._num += 1
505-
506487 def __format_key (self , i ):
507488 return f'l-{ i } '
508489
@@ -518,19 +499,17 @@ def __all_nodes(self):
518499
519500 def __getitem__ (self , key : Union [int , slice , str ]):
520501 if isinstance (key , str ):
521- if key in self ._dyn_modules :
522- return self ._dyn_modules [key ]
523- elif key in self ._static_modules :
524- return self ._static_modules [key ]
502+ if key in self .children :
503+ return self .children [key ]
525504 else :
526505 raise KeyError (f'Does not find a component named { key } in\n { str (self )} ' )
527506 elif isinstance (key , slice ):
528- return Sequential (* ( self .__all_nodes ( )[key ]))
507+ return Sequential (** dict ( tuple ( self .children . items () )[key ]))
529508 elif isinstance (key , int ):
530- return self .__all_nodes ( )[key ]
509+ return tuple ( self .children . values () )[key ]
531510 elif isinstance (key , (tuple , list )):
532- _all_nodes = self .__all_nodes ( )
533- return Sequential (* [ _all_nodes [k ] for k in key ] )
511+ _all_nodes = tuple ( self .children . items () )
512+ return Sequential (** dict ( _all_nodes [k ] for k in key ) )
534513 else :
535514 raise KeyError (f'Unknown type of key: { type (key )} ' )
536515
@@ -653,7 +632,7 @@ def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None):
653632 batch_axis_name = bm .sharding .BATCH_AXIS )
654633
655634 def __repr__ (self ):
656- return f'{ self .__class__ . __name__ } ( name= { self . name } , mode={ self .mode } , size={ self .size } )'
635+ return f'{ self .name } ( mode={ self .mode } , size={ self .size } )'
657636
658637 def __getitem__ (self , item ):
659638 return DynView (target = self , index = item )
0 commit comments