1+ # ============================================
12# Copyright (c) 2025, InfiniCore
2- #
3+ #
34# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
45# but based on InfiniCoreModule for inference purposes.
56
6- from typing import List , Optional , Iterator , Union , Sequence , TypeVar
7- import torch
87import operator
9- from itertools import chain
108from collections import OrderedDict
11- from .module import InfiniCoreModule
9+ from itertools import chain
10+ from typing import Iterator , List , Optional , Sequence , TypeVar , Union
1211
13- # Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule)
14- ModuleType = TypeVar ('ModuleType' , bound = Union [torch .nn .Module , 'InfiniCoreModule' ])
12+ from .module import InfiniCoreModule as Module
1513
14+ # Define type variable for module compatibility (supports InfiniCoreModule)
15+ ModuleType = TypeVar ("ModuleType" , bound = Union ["Module" ])
1616
17- class InfiniCoreModuleList (InfiniCoreModule ):
17+
18+ class InfiniCoreModuleList (Module ):
1819 r"""Holds submodules in a list.
1920
2021 InfiniCoreModuleList can be indexed like a regular Python list, but
@@ -54,7 +55,9 @@ def _get_abs_string_index(self, idx):
5455 idx += len (self )
5556 return str (idx )
5657
57- def __getitem__ (self , idx : Union [int , slice ]) -> Union [ModuleType , 'InfiniCoreModuleList' ]:
58+ def __getitem__ (
59+ self , idx : Union [int , slice ]
60+ ) -> Union [ModuleType , "InfiniCoreModuleList" ]:
5861 if isinstance (idx , slice ):
5962 return self .__class__ (list (self ._modules .values ())[idx ])
6063 else :
@@ -75,7 +78,7 @@ def __delitem__(self, idx: Union[int, slice]) -> None:
7578 idx_str = self ._get_abs_string_index (idx )
7679 if idx_str in self ._modules :
7780 del self ._modules [idx_str ]
78-
81+
7982 # To preserve numbering, self._modules is being reconstructed with modules after deletion
8083 if len (self ._modules ) > 0 :
8184 str_indices = [str (i ) for i in range (len (self ._modules ))]
@@ -87,10 +90,12 @@ def __len__(self) -> int:
8790 def __iter__ (self ) -> Iterator [ModuleType ]:
8891 return iter (self ._modules .values ())
8992
90- def __iadd__ (self , modules : Sequence [ModuleType ]) -> ' InfiniCoreModuleList' :
93+ def __iadd__ (self , modules : Sequence [ModuleType ]) -> " InfiniCoreModuleList" :
9194 return self .extend (modules )
9295
93- def __add__ (self , other : Union [Sequence [ModuleType ], 'InfiniCoreModuleList' ]) -> 'InfiniCoreModuleList' :
96+ def __add__ (
97+ self , other : Union [Sequence [ModuleType ], "InfiniCoreModuleList" ]
98+ ) -> "InfiniCoreModuleList" :
9499 r"""Return a new InfiniCoreModuleList by concatenating with another iterable.
95100
96101 Args:
@@ -101,22 +106,22 @@ def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) ->
101106 f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
102107 f"got { type (other ).__name__ } "
103108 )
104-
109+
105110 combined = InfiniCoreModuleList ()
106111 for i , module in enumerate (chain (self , other )):
107112 combined .add_module (str (i ), module )
108113 return combined
109114
110- def append (self , module : ModuleType ) -> ' InfiniCoreModuleList' :
115+ def append (self , module : ModuleType ) -> " InfiniCoreModuleList" :
111116 r"""Append a given module to the end of the list.
112117
113118 Args:
114- module (nn.Module or InfiniCoreModule): module to append
119+ module (InfiniCoreModule): module to append
115120 """
116121 self .add_module (str (len (self )), module )
117122 return self
118123
119- def extend (self , modules : Sequence [ModuleType ]) -> ' InfiniCoreModuleList' :
124+ def extend (self , modules : Sequence [ModuleType ]) -> " InfiniCoreModuleList" :
120125 r"""Append modules from a Python iterable to the end of the list.
121126
122127 Args:
@@ -130,7 +135,7 @@ def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
130135 f"InfiniCoreModuleList.extend should be called with an "
131136 f"iterable, but got { type (modules ).__name__ } "
132137 )
133-
138+
134139 offset = len (self )
135140 for i , module in enumerate (modules ):
136141 self .add_module (str (offset + i ), module )
@@ -141,7 +146,7 @@ def insert(self, index: int, module: ModuleType) -> None:
141146
142147 Args:
143148 index (int): index to insert.
144- module (nn.Module or InfiniCoreModule): module to insert
149+ module ( InfiniCoreModule): module to insert
145150 """
146151 for i in range (len (self ._modules ), index , - 1 ):
147152 self ._modules [str (i )] = self ._modules [str (i - 1 )]
@@ -166,11 +171,11 @@ def __repr__(self) -> str:
166171 """Return a string representation of the ModuleList."""
167172 if len (self ) == 0 :
168173 return self .__class__ .__name__ + "()"
169-
174+
170175 lines = []
171176 for i , module in enumerate (self ):
172177 lines .append (f"({ i } ): { repr (module )} " )
173-
178+
174179 main_str = self .__class__ .__name__ + "(\n "
175180 main_str += "\n " .join (lines ) + "\n )"
176181 return main_str
0 commit comments