@@ -105,7 +105,7 @@ def remove_from(*dicts_or_sets) -> None:
105105 self .register_parameter (name , value )
106106 else :
107107 modules = self .__dict__ .get ("_modules" )
108- if isinstance (value , (torch .nn .Module )):
108+ if isinstance (value , (torch .nn .Module , InfiniCoreModule )):
109109 if modules is None :
110110 raise AttributeError (
111111 "cannot assign module before Module.__init__() call"
@@ -181,6 +181,33 @@ def register_buffer(self, name: str, tensor: Optional[torch.tensor], persistent:
181181 self ._non_persistent_buffers_set .add (name )
182182
183183
184+ def add_module (self , name : str , module : Optional [torch .nn .Module ]) -> None :
185+ r"""Add a child module to the current module.
186+
187+ The module can be accessed as an attribute using the given name.
188+
189+ Args:
190+ name (str): name of the child module. The child module can be
191+ accessed from this module using the given name
192+ module (Module or None): child module to be added to the module. If
193+ ``None``, then operations that run on modules, such as :attr:`eval`,
194+ are ignored. If ``None``, the module is **not** included in the
195+ module's :attr:`children`.
196+ """
197+ if not isinstance (name , str ):
198+ raise TypeError (f"module name should be a string. Got { torch .typename (name )} " )
199+ elif '.' in name :
200+ raise KeyError (f"module name can't contain \" .\" , got: { name } " )
201+ elif name == '' :
202+ raise KeyError ("module name can't be empty string \" \" " )
203+ elif hasattr (self , name ) and name not in self ._modules :
204+ raise KeyError (f"attribute '{ name } ' already exists" )
205+
206+ if module is not None and not isinstance (module , (torch .nn .Module , InfiniCoreModule )):
207+ raise TypeError (f"{ torch .typename (module )} is not a Module subclass" )
208+
209+ self ._modules [name ] = module
210+
184211 def register_parameter (self , name : str , param : Optional [torch .nn .Parameter ]) -> None :
185212 r"""Add a parameter to the module.
186213
@@ -526,16 +553,212 @@ def load(module, local_state_dict, prefix=''):
526553 self .__class__ .__name__ , "\n \t " .join (error_msgs )))
527554 return _IncompatibleKeys (missing_keys , unexpected_keys )
528555
529- def children (self ) -> Iterator ['InfiniCoreModule' ]:
556+ def parameters (self , recurse : bool = True ) -> Iterator [torch .nn .Parameter ]:
557+ r"""Returns an iterator over module parameters.
558+
559+ Args:
560+ recurse (bool): if True, then yields parameters of this module
561+ and all submodules. Otherwise, yields only parameters that
562+ are direct members of this module.
563+
564+ Yields:
565+ Parameter: module parameter
566+
567+ Example::
568+
569+ >>> # xdoctest: +SKIP("undefined vars")
570+ >>> for param in model.parameters():
571+ ... print(type(param), param.size())
572+
573+ """
574+ for name , param in self .named_parameters (recurse = recurse ):
575+ yield param
576+
577+ def named_parameters (self , prefix : str = '' , recurse : bool = True ) -> Iterator [Tuple [str , torch .nn .Parameter ]]:
578+ r"""Returns an iterator over module parameters, yielding both the
579+ name of the parameter as well as the parameter itself.
580+
581+ Args:
582+ prefix (str): prefix to prepend to all parameter names.
583+ recurse (bool): if True, then yields parameters of this module
584+ and all submodules. Otherwise, yields only parameters that
585+ are direct members of this module.
586+
587+ Yields:
588+ (str, Parameter): Tuple containing the name and parameter
589+
590+ Example::
591+
592+ >>> # xdoctest: +SKIP("undefined vars")
593+ >>> for name, param in self.named_parameters():
594+ ... if name in ['bias']:
595+ ... print(param.size())
596+
597+ """
598+ gen = self ._named_members (
599+ lambda module : module ._parameters .items (),
600+ prefix = prefix , recurse = recurse )
601+ for elem in gen :
602+ yield elem
603+
604+ def buffers (self , recurse : bool = True ) -> Iterator [torch .Tensor ]:
605+ r"""Returns an iterator over module buffers.
606+
607+ Args:
608+ recurse (bool): if True, then yields buffers of this module
609+ and all submodules. Otherwise, yields only buffers that
610+ are direct members of this module.
611+
612+ Yields:
613+ torch.Tensor: module buffer
614+
615+ Example::
616+
617+ >>> # xdoctest: +SKIP("undefined vars")
618+ >>> for buf in model.buffers():
619+ ... print(type(buf), buf.size())
620+
621+ """
622+ for name , buf in self .named_buffers (recurse = recurse ):
623+ yield buf
624+
625+ def named_buffers (self , prefix : str = '' , recurse : bool = True ) -> Iterator [Tuple [str , torch .Tensor ]]:
626+ r"""Returns an iterator over module buffers, yielding both the
627+ name of the buffer as well as the buffer itself.
628+
629+ Args:
630+ prefix (str): prefix to prepend to all buffer names.
631+ recurse (bool): if True, then yields buffers of this module
632+ and all submodules. Otherwise, yields only buffers that
633+ are direct members of this module.
634+
635+ Yields:
636+ (str, torch.Tensor): Tuple containing the name and buffer
637+
638+ Example::
639+
640+ >>> # xdoctest: +SKIP("undefined vars")
641+ >>> for name, buf in self.named_buffers():
642+ ... if name in ['running_mean']:
643+ ... print(buf.size())
644+
645+ """
646+ memo = set ()
647+ modules = self .named_modules (prefix = prefix ) if recurse else [(prefix , self )]
648+ for module_prefix , module in modules :
649+ for k , v in module ._buffers .items ():
650+ if v is None or v in memo :
651+ continue
652+ if k in module ._non_persistent_buffers_set :
653+ continue
654+ memo .add (v )
655+ name = module_prefix + ('.' if module_prefix else '' ) + k
656+ yield (name , v )
657+
658+ def _named_members (self , get_members_fn , prefix = '' , recurse = True ):
659+ r"""Helper method to yield members with their names."""
660+ memo = set ()
661+ modules = self .named_modules (prefix = prefix ) if recurse else [(prefix , self )]
662+ for module_prefix , module in modules :
663+ members = get_members_fn (module )
664+ for k , v in members :
665+ if v is None or v in memo :
666+ continue
667+ memo .add (v )
668+ name = module_prefix + ('.' if module_prefix else '' ) + k
669+ yield (name , v )
670+
671+ def modules (self ) -> Iterator ['InfiniCoreModule' ]:
672+ r"""Returns an iterator over all modules in the network.
673+
674+ Yields:
675+ Module: a module in the network
676+
677+ Note:
678+ Duplicate modules are returned only once. In the following
679+ example, ``l`` will be returned only once.
680+
681+ Example::
682+
683+ >>> # xdoctest: +SKIP("undefined vars")
684+ >>> l = nn.Linear(2, 2)
685+ >>> net = nn.Sequential(l, l)
686+ >>> for idx, m in enumerate(net.modules()):
687+ ... print(idx, '->', m)
688+
689+ 0 -> Sequential(
690+ (0): Linear(in_features=2, out_features=2, bias=True)
691+ (1): Linear(in_features=2, out_features=2, bias=True)
692+ )
693+ 1 -> Linear(in_features=2, out_features=2, bias=True)
694+
695+ """
696+ for name , module in self .named_modules ():
697+ yield module
698+
699+ def named_modules (self , memo : Optional [Set ['InfiniCoreModule' ]] = None , prefix : str = '' , remove_duplicate : bool = True ):
700+ r"""Returns an iterator over all modules in the network, yielding
701+ both the name of the module as well as the module itself.
702+
703+ Args:
704+ memo: a memo to store the set of modules already added to the result
705+ prefix: a prefix that will be added to the name of the module
706+ remove_duplicate: whether to remove the duplicated module instances in the result
707+ or not
708+
709+ Yields:
710+ (str, Module): Tuple of name and module
711+
712+ Note:
713+ Duplicate modules are returned only once. In the following
714+ example, ``l`` will be returned only once.
715+
716+ Example::
717+
718+ >>> # xdoctest: +SKIP("undefined vars")
719+ >>> l = nn.Linear(2, 2)
720+ >>> net = nn.Sequential(l, l)
721+ >>> for idx, m in enumerate(net.named_modules()):
722+ ... print(idx, '->', m)
723+
724+ 0 -> ('', Sequential(
725+ (0): Linear(in_features=2, out_features=2, bias=True)
726+ (1): Linear(in_features=2, out_features=2, bias=True)
727+ ))
728+ 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
729+
730+ """
731+ if memo is None :
732+ memo = set ()
733+ if remove_duplicate :
734+ if self in memo :
735+ return
736+ memo .add (self )
737+ yield prefix , self
738+ for name , module in self ._modules .items ():
739+ if module is None :
740+ continue
741+ submodule_prefix = prefix + ('.' if prefix else '' ) + name
742+ # Handle both InfiniCoreModule and torch.nn.Module
743+ if isinstance (module , InfiniCoreModule ):
744+ for m in module .named_modules (memo , submodule_prefix , remove_duplicate ):
745+ yield m
746+ elif isinstance (module , torch .nn .Module ):
747+ # For torch.nn.Module, use its named_modules method
748+ # torch.nn.Module.named_modules returns (name, module) tuples
749+ for sub_name , sub_module in module .named_modules (prefix = submodule_prefix , remove_duplicate = remove_duplicate ):
750+ yield (sub_name , sub_module )
751+
752+ def children (self ) -> Iterator [Union ['InfiniCoreModule' , torch .nn .Module ]]:
530753 r"""Returns an iterator over immediate children modules.
531754
532755 Yields:
533- Module: a child module
756+ Module: a child module (can be InfiniCoreModule or torch.nn.Module)
534757 """
535758 for name , module in self .named_children ():
536759 yield module
537760
538- def named_children (self ) -> Iterator [Tuple [str , 'InfiniCoreModule' ]]:
761+ def named_children (self ) -> Iterator [Tuple [str , Union [ 'InfiniCoreModule' , torch . nn . Module ] ]]:
539762 r"""Returns an iterator over immediate children modules, yielding both
540763 the name of the module as well as the module itself.
541764
0 commit comments