Skip to content

Commit 160fd18

Browse files
zhuyuegongchensu
authored andcommitted
feat: add infinicore.nn.InfiniCoreModuleList referencing torch.nn.ModuleList.
add some functions in InfiniCoreModule.
1 parent 4305533 commit 160fd18

File tree

5 files changed

+732
-5
lines changed

5 files changed

+732
-5
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .module import InfiniCoreModule as Module
2+
from .module_list import InfiniCoreModuleList as ModuleList

python/infinicore/nn/modules/module.py

Lines changed: 227 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def remove_from(*dicts_or_sets) -> None:
105105
self.register_parameter(name, value)
106106
else:
107107
modules = self.__dict__.get("_modules")
108-
if isinstance(value, (torch.nn.Module)):
108+
if isinstance(value, (torch.nn.Module, InfiniCoreModule)):
109109
if modules is None:
110110
raise AttributeError(
111111
"cannot assign module before Module.__init__() call"
@@ -181,6 +181,33 @@ def register_buffer(self, name: str, tensor: Optional[torch.tensor], persistent:
181181
self._non_persistent_buffers_set.add(name)
182182

183183

184+
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
185+
r"""Add a child module to the current module.
186+
187+
The module can be accessed as an attribute using the given name.
188+
189+
Args:
190+
name (str): name of the child module. The child module can be
191+
accessed from this module using the given name
192+
module (Module or None): child module to be added to the module. If
193+
``None``, then operations that run on modules, such as :attr:`eval`,
194+
are ignored. If ``None``, the module is **not** included in the
195+
module's :attr:`children`.
196+
"""
197+
if not isinstance(name, str):
198+
raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
199+
elif '.' in name:
200+
raise KeyError(f"module name can't contain \".\", got: {name}")
201+
elif name == '':
202+
raise KeyError("module name can't be empty string \"\"")
203+
elif hasattr(self, name) and name not in self._modules:
204+
raise KeyError(f"attribute '{name}' already exists")
205+
206+
if module is not None and not isinstance(module, (torch.nn.Module, InfiniCoreModule)):
207+
raise TypeError(f"{torch.typename(module)} is not a Module subclass")
208+
209+
self._modules[name] = module
210+
184211
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
185212
r"""Add a parameter to the module.
186213
@@ -526,16 +553,212 @@ def load(module, local_state_dict, prefix=''):
526553
self.__class__.__name__, "\n\t".join(error_msgs)))
527554
return _IncompatibleKeys(missing_keys, unexpected_keys)
528555

529-
def children(self) -> Iterator['InfiniCoreModule']:
556+
def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
557+
r"""Returns an iterator over module parameters.
558+
559+
Args:
560+
recurse (bool): if True, then yields parameters of this module
561+
and all submodules. Otherwise, yields only parameters that
562+
are direct members of this module.
563+
564+
Yields:
565+
Parameter: module parameter
566+
567+
Example::
568+
569+
>>> # xdoctest: +SKIP("undefined vars")
570+
>>> for param in model.parameters():
571+
... print(type(param), param.size())
572+
573+
"""
574+
for name, param in self.named_parameters(recurse=recurse):
575+
yield param
576+
577+
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
578+
r"""Returns an iterator over module parameters, yielding both the
579+
name of the parameter as well as the parameter itself.
580+
581+
Args:
582+
prefix (str): prefix to prepend to all parameter names.
583+
recurse (bool): if True, then yields parameters of this module
584+
and all submodules. Otherwise, yields only parameters that
585+
are direct members of this module.
586+
587+
Yields:
588+
(str, Parameter): Tuple containing the name and parameter
589+
590+
Example::
591+
592+
>>> # xdoctest: +SKIP("undefined vars")
593+
>>> for name, param in self.named_parameters():
594+
... if name in ['bias']:
595+
... print(param.size())
596+
597+
"""
598+
gen = self._named_members(
599+
lambda module: module._parameters.items(),
600+
prefix=prefix, recurse=recurse)
601+
for elem in gen:
602+
yield elem
603+
604+
def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]:
605+
r"""Returns an iterator over module buffers.
606+
607+
Args:
608+
recurse (bool): if True, then yields buffers of this module
609+
and all submodules. Otherwise, yields only buffers that
610+
are direct members of this module.
611+
612+
Yields:
613+
torch.Tensor: module buffer
614+
615+
Example::
616+
617+
>>> # xdoctest: +SKIP("undefined vars")
618+
>>> for buf in model.buffers():
619+
... print(type(buf), buf.size())
620+
621+
"""
622+
for name, buf in self.named_buffers(recurse=recurse):
623+
yield buf
624+
625+
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]:
626+
r"""Returns an iterator over module buffers, yielding both the
627+
name of the buffer as well as the buffer itself.
628+
629+
Args:
630+
prefix (str): prefix to prepend to all buffer names.
631+
recurse (bool): if True, then yields buffers of this module
632+
and all submodules. Otherwise, yields only buffers that
633+
are direct members of this module.
634+
635+
Yields:
636+
(str, torch.Tensor): Tuple containing the name and buffer
637+
638+
Example::
639+
640+
>>> # xdoctest: +SKIP("undefined vars")
641+
>>> for name, buf in self.named_buffers():
642+
... if name in ['running_mean']:
643+
... print(buf.size())
644+
645+
"""
646+
memo = set()
647+
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
648+
for module_prefix, module in modules:
649+
for k, v in module._buffers.items():
650+
if v is None or v in memo:
651+
continue
652+
if k in module._non_persistent_buffers_set:
653+
continue
654+
memo.add(v)
655+
name = module_prefix + ('.' if module_prefix else '') + k
656+
yield (name, v)
657+
658+
def _named_members(self, get_members_fn, prefix='', recurse=True):
659+
r"""Helper method to yield members with their names."""
660+
memo = set()
661+
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
662+
for module_prefix, module in modules:
663+
members = get_members_fn(module)
664+
for k, v in members:
665+
if v is None or v in memo:
666+
continue
667+
memo.add(v)
668+
name = module_prefix + ('.' if module_prefix else '') + k
669+
yield (name, v)
670+
671+
def modules(self) -> Iterator['InfiniCoreModule']:
672+
r"""Returns an iterator over all modules in the network.
673+
674+
Yields:
675+
Module: a module in the network
676+
677+
Note:
678+
Duplicate modules are returned only once. In the following
679+
example, ``l`` will be returned only once.
680+
681+
Example::
682+
683+
>>> # xdoctest: +SKIP("undefined vars")
684+
>>> l = nn.Linear(2, 2)
685+
>>> net = nn.Sequential(l, l)
686+
>>> for idx, m in enumerate(net.modules()):
687+
... print(idx, '->', m)
688+
689+
0 -> Sequential(
690+
(0): Linear(in_features=2, out_features=2, bias=True)
691+
(1): Linear(in_features=2, out_features=2, bias=True)
692+
)
693+
1 -> Linear(in_features=2, out_features=2, bias=True)
694+
695+
"""
696+
for name, module in self.named_modules():
697+
yield module
698+
699+
def named_modules(self, memo: Optional[Set['InfiniCoreModule']] = None, prefix: str = '', remove_duplicate: bool = True):
700+
r"""Returns an iterator over all modules in the network, yielding
701+
both the name of the module as well as the module itself.
702+
703+
Args:
704+
memo: a memo to store the set of modules already added to the result
705+
prefix: a prefix that will be added to the name of the module
706+
remove_duplicate: whether to remove the duplicated module instances in the result
707+
or not
708+
709+
Yields:
710+
(str, Module): Tuple of name and module
711+
712+
Note:
713+
Duplicate modules are returned only once. In the following
714+
example, ``l`` will be returned only once.
715+
716+
Example::
717+
718+
>>> # xdoctest: +SKIP("undefined vars")
719+
>>> l = nn.Linear(2, 2)
720+
>>> net = nn.Sequential(l, l)
721+
>>> for idx, m in enumerate(net.named_modules()):
722+
... print(idx, '->', m)
723+
724+
0 -> ('', Sequential(
725+
(0): Linear(in_features=2, out_features=2, bias=True)
726+
(1): Linear(in_features=2, out_features=2, bias=True)
727+
))
728+
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
729+
730+
"""
731+
if memo is None:
732+
memo = set()
733+
if remove_duplicate:
734+
if self in memo:
735+
return
736+
memo.add(self)
737+
yield prefix, self
738+
for name, module in self._modules.items():
739+
if module is None:
740+
continue
741+
submodule_prefix = prefix + ('.' if prefix else '') + name
742+
# Handle both InfiniCoreModule and torch.nn.Module
743+
if isinstance(module, InfiniCoreModule):
744+
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
745+
yield m
746+
elif isinstance(module, torch.nn.Module):
747+
# For torch.nn.Module, use its named_modules method
748+
# torch.nn.Module.named_modules returns (name, module) tuples
749+
for sub_name, sub_module in module.named_modules(prefix=submodule_prefix, remove_duplicate=remove_duplicate):
750+
yield (sub_name, sub_module)
751+
752+
def children(self) -> Iterator[Union['InfiniCoreModule', torch.nn.Module]]:
530753
r"""Returns an iterator over immediate children modules.
531754
532755
Yields:
533-
Module: a child module
756+
Module: a child module (can be InfiniCoreModule or torch.nn.Module)
534757
"""
535758
for name, module in self.named_children():
536759
yield module
537760

538-
def named_children(self) -> Iterator[Tuple[str, 'InfiniCoreModule']]:
761+
def named_children(self) -> Iterator[Tuple[str, Union['InfiniCoreModule', torch.nn.Module]]]:
539762
r"""Returns an iterator over immediate children modules, yielding both
540763
the name of the module as well as the module itself.
541764

0 commit comments

Comments
 (0)