|
| 1 | +# ============================================ |
| 2 | +# Copyright (c) 2025, InfiniCore |
| 3 | +# |
| 4 | +# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList |
| 5 | +# but based on InfiniCoreModule for inference purposes. |
| 6 | + |
| 7 | +import operator |
| 8 | +from collections import OrderedDict |
| 9 | +from itertools import chain |
| 10 | +from typing import Iterator, List, Optional, Sequence, TypeVar, Union |
| 11 | + |
| 12 | +from .module import InfiniCoreModule as Module |
| 13 | + |
| 14 | +# Define type variable for module compatibility (supports InfiniCoreModule) |
| 15 | +ModuleType = TypeVar("ModuleType", bound=Union["Module"]) |
| 16 | + |
| 17 | + |
| 18 | +class InfiniCoreModuleList(Module): |
| 19 | + r"""Holds submodules in a list. |
| 20 | +
|
| 21 | + InfiniCoreModuleList can be indexed like a regular Python list, but |
| 22 | + modules it contains are properly registered, and will be visible by all |
| 23 | + InfiniCoreModule methods. |
| 24 | +
|
| 25 | + Args: |
| 26 | + modules (iterable, optional): an iterable of modules to add |
| 27 | +
|
| 28 | + Example:: |
| 29 | +
|
| 30 | + >>> class MyModel(InfiniCoreModule): |
| 31 | + ... def __init__(self): |
| 32 | + ... super().__init__() |
| 33 | + ... self.linears = InfiniCoreModuleList([ |
| 34 | + ... torch.nn.Linear(10, 10) for i in range(10) |
| 35 | + ... ]) |
| 36 | + ... |
| 37 | + ... def forward(self, x): |
| 38 | + ... # ModuleList can act as an iterable, or be indexed using ints |
| 39 | + ... for i, l in enumerate(self.linears): |
| 40 | + ... x = self.linears[i // 2](x) + l(x) |
| 41 | + ... return x |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__(self, modules: Optional[Sequence[ModuleType]] = None): |
| 45 | + super().__init__() |
| 46 | + if modules is not None: |
| 47 | + self += modules |
| 48 | + |
| 49 | + def _get_abs_string_index(self, idx): |
| 50 | + """Get the absolute index for the list of modules.""" |
| 51 | + idx = operator.index(idx) |
| 52 | + if not (-len(self) <= idx < len(self)): |
| 53 | + raise IndexError(f"index {idx} is out of range") |
| 54 | + if idx < 0: |
| 55 | + idx += len(self) |
| 56 | + return str(idx) |
| 57 | + |
| 58 | + def __getitem__( |
| 59 | + self, idx: Union[int, slice] |
| 60 | + ) -> Union[ModuleType, "InfiniCoreModuleList"]: |
| 61 | + if isinstance(idx, slice): |
| 62 | + return self.__class__(list(self._modules.values())[idx]) |
| 63 | + else: |
| 64 | + return self._modules[self._get_abs_string_index(idx)] |
| 65 | + |
| 66 | + def __setitem__(self, idx: int, module: ModuleType) -> None: |
| 67 | + idx = self._get_abs_string_index(idx) |
| 68 | + # Use add_module to register module |
| 69 | + self.add_module(idx, module) |
| 70 | + |
| 71 | + def __delitem__(self, idx: Union[int, slice]) -> None: |
| 72 | + if isinstance(idx, slice): |
| 73 | + indices_to_delete = list(range(len(self._modules)))[idx] |
| 74 | + for k in indices_to_delete: |
| 75 | + if str(k) in self._modules: |
| 76 | + del self._modules[str(k)] |
| 77 | + else: |
| 78 | + idx_str = self._get_abs_string_index(idx) |
| 79 | + if idx_str in self._modules: |
| 80 | + del self._modules[idx_str] |
| 81 | + |
| 82 | + # To preserve numbering, self._modules is being reconstructed with modules after deletion |
| 83 | + if len(self._modules) > 0: |
| 84 | + str_indices = [str(i) for i in range(len(self._modules))] |
| 85 | + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) |
| 86 | + |
| 87 | + def __len__(self) -> int: |
| 88 | + return len(self._modules) |
| 89 | + |
| 90 | + def __iter__(self) -> Iterator[ModuleType]: |
| 91 | + return iter(self._modules.values()) |
| 92 | + |
| 93 | + def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList": |
| 94 | + return self.extend(modules) |
| 95 | + |
| 96 | + def __add__( |
| 97 | + self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"] |
| 98 | + ) -> "InfiniCoreModuleList": |
| 99 | + r"""Return a new InfiniCoreModuleList by concatenating with another iterable. |
| 100 | +
|
| 101 | + Args: |
| 102 | + other (iterable): iterable of modules to concatenate |
| 103 | + """ |
| 104 | + if not isinstance(other, (list, tuple, InfiniCoreModuleList)): |
| 105 | + raise TypeError( |
| 106 | + f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, " |
| 107 | + f"got {type(other).__name__}" |
| 108 | + ) |
| 109 | + |
| 110 | + combined = InfiniCoreModuleList() |
| 111 | + for i, module in enumerate(chain(self, other)): |
| 112 | + combined.add_module(str(i), module) |
| 113 | + return combined |
| 114 | + |
| 115 | + def append(self, module: ModuleType) -> "InfiniCoreModuleList": |
| 116 | + r"""Append a given module to the end of the list. |
| 117 | +
|
| 118 | + Args: |
| 119 | + module (InfiniCoreModule): module to append |
| 120 | + """ |
| 121 | + self.add_module(str(len(self)), module) |
| 122 | + return self |
| 123 | + |
| 124 | + def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList": |
| 125 | + r"""Append modules from a Python iterable to the end of the list. |
| 126 | +
|
| 127 | + Args: |
| 128 | + modules (iterable): iterable of modules to append |
| 129 | + """ |
| 130 | + if not isinstance(modules, (list, tuple)): |
| 131 | + try: |
| 132 | + modules = list(modules) |
| 133 | + except TypeError: |
| 134 | + raise TypeError( |
| 135 | + f"InfiniCoreModuleList.extend should be called with an " |
| 136 | + f"iterable, but got {type(modules).__name__}" |
| 137 | + ) |
| 138 | + |
| 139 | + offset = len(self) |
| 140 | + for i, module in enumerate(modules): |
| 141 | + self.add_module(str(offset + i), module) |
| 142 | + return self |
| 143 | + |
| 144 | + def insert(self, index: int, module: ModuleType) -> None: |
| 145 | + r"""Insert a given module before a given index in the list. |
| 146 | +
|
| 147 | + Args: |
| 148 | + index (int): index to insert. |
| 149 | + module ( InfiniCoreModule): module to insert |
| 150 | + """ |
| 151 | + for i in range(len(self._modules), index, -1): |
| 152 | + self._modules[str(i)] = self._modules[str(i - 1)] |
| 153 | + self._modules[str(index)] = module |
| 154 | + |
| 155 | + def pop(self, idx: int = -1) -> ModuleType: |
| 156 | + r"""Remove and return a module at the given index. |
| 157 | +
|
| 158 | + Args: |
| 159 | + idx (int): index of the module to pop. Default: -1 (last module) |
| 160 | +
|
| 161 | + Returns: |
| 162 | + Module: the module that was removed |
| 163 | + """ |
| 164 | + idx_str = self._get_abs_string_index(idx) |
| 165 | + module = self._modules[idx_str] |
| 166 | + # Use __delitem__ to ensure proper cleanup |
| 167 | + self.__delitem__(int(idx_str)) |
| 168 | + return module |
| 169 | + |
| 170 | + def __repr__(self) -> str: |
| 171 | + """Return a string representation of the ModuleList.""" |
| 172 | + if len(self) == 0: |
| 173 | + return self.__class__.__name__ + "()" |
| 174 | + |
| 175 | + lines = [] |
| 176 | + for i, module in enumerate(self): |
| 177 | + lines.append(f"({i}): {repr(module)}") |
| 178 | + |
| 179 | + main_str = self.__class__.__name__ + "(\n " |
| 180 | + main_str += "\n ".join(lines) + "\n)" |
| 181 | + return main_str |
| 182 | + |
| 183 | + def __dir__(self) -> List[str]: |
| 184 | + """Return a list of attribute names, excluding numeric keys.""" |
| 185 | + keys = super().__dir__() |
| 186 | + # Filter out numeric keys to avoid cluttering dir() output |
| 187 | + keys = [key for key in keys if not key.isdigit()] |
| 188 | + return keys |
0 commit comments