Skip to content

Commit d4a9bea

Browse files
authored
Merge pull request #1 from pengcheng888/feature/add_nn_interface_wpc
issue/567-只处理infinicore.Tensor,能够加载infinicore.Tensor的权重,修改了module.py …
2 parents 4e35d6a + 72a61fd commit d4a9bea

File tree

11 files changed

+837
-1110
lines changed

11 files changed

+837
-1110
lines changed

python/infinicore/nn/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
from infinicore.nn import (
22
functional as functional,
33
)
4+
from infinicore.nn import (
5+
modules as modules,
6+
)
7+
from infinicore.nn.functional import * # noqa: F403
8+
from infinicore.nn.modules import * # noqa: F403
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .module import InfiniCoreModule as Module
2-
from .module_list import InfiniCoreModuleList as ModuleList
3-
from .parameter import InfiniCoreParameter as Parameter
1+
from .container import ModuleList
2+
from .module import Module
3+
from .parameter import Parameter

python/infinicore/nn/modules/module_list.py renamed to python/infinicore/nn/modules/container.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1+
# ============================================
12
# Copyright (c) 2025, InfiniCore
2-
#
3+
#
34
# This file implements InfiniCoreModuleList, which is similar to torch.nn.ModuleList
45
# but based on InfiniCoreModule for inference purposes.
56

6-
from typing import List, Optional, Iterator, Union, Sequence, TypeVar
7-
import torch
87
import operator
9-
from itertools import chain
108
from collections import OrderedDict
11-
from .module import InfiniCoreModule
9+
from itertools import chain
10+
from typing import Iterator, List, Optional, Sequence, TypeVar, Union
11+
12+
from .module import Module
1213

13-
# Define type variable for module compatibility (supports both torch.nn.Module and InfiniCoreModule)
14-
ModuleType = TypeVar('ModuleType', bound=Union[torch.nn.Module, 'InfiniCoreModule'])
14+
# Define type variable for module compatibility (supports InfiniCoreModule)
15+
ModuleType = TypeVar("ModuleType", bound=Union["Module"])
1516

1617

17-
class InfiniCoreModuleList(InfiniCoreModule):
18+
class InfiniCoreModuleList(Module):
1819
r"""Holds submodules in a list.
1920
2021
InfiniCoreModuleList can be indexed like a regular Python list, but
@@ -54,7 +55,9 @@ def _get_abs_string_index(self, idx):
5455
idx += len(self)
5556
return str(idx)
5657

57-
def __getitem__(self, idx: Union[int, slice]) -> Union[ModuleType, 'InfiniCoreModuleList']:
58+
def __getitem__(
59+
self, idx: Union[int, slice]
60+
) -> Union[ModuleType, "InfiniCoreModuleList"]:
5861
if isinstance(idx, slice):
5962
return self.__class__(list(self._modules.values())[idx])
6063
else:
@@ -75,7 +78,7 @@ def __delitem__(self, idx: Union[int, slice]) -> None:
7578
idx_str = self._get_abs_string_index(idx)
7679
if idx_str in self._modules:
7780
del self._modules[idx_str]
78-
81+
7982
# To preserve numbering, self._modules is being reconstructed with modules after deletion
8083
if len(self._modules) > 0:
8184
str_indices = [str(i) for i in range(len(self._modules))]
@@ -87,10 +90,12 @@ def __len__(self) -> int:
8790
def __iter__(self) -> Iterator[ModuleType]:
8891
return iter(self._modules.values())
8992

90-
def __iadd__(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
93+
def __iadd__(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
9194
return self.extend(modules)
9295

93-
def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) -> 'InfiniCoreModuleList':
96+
def __add__(
97+
self, other: Union[Sequence[ModuleType], "InfiniCoreModuleList"]
98+
) -> "InfiniCoreModuleList":
9499
r"""Return a new InfiniCoreModuleList by concatenating with another iterable.
95100
96101
Args:
@@ -101,22 +106,22 @@ def __add__(self, other: Union[Sequence[ModuleType], 'InfiniCoreModuleList']) ->
101106
f"InfiniCoreModuleList can only be concatenated with list, tuple, or InfiniCoreModuleList, "
102107
f"got {type(other).__name__}"
103108
)
104-
109+
105110
combined = InfiniCoreModuleList()
106111
for i, module in enumerate(chain(self, other)):
107112
combined.add_module(str(i), module)
108113
return combined
109114

110-
def append(self, module: ModuleType) -> 'InfiniCoreModuleList':
115+
def append(self, module: ModuleType) -> "InfiniCoreModuleList":
111116
r"""Append a given module to the end of the list.
112117
113118
Args:
114-
module (nn.Module or InfiniCoreModule): module to append
119+
module (InfiniCoreModule): module to append
115120
"""
116121
self.add_module(str(len(self)), module)
117122
return self
118123

119-
def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
124+
def extend(self, modules: Sequence[ModuleType]) -> "InfiniCoreModuleList":
120125
r"""Append modules from a Python iterable to the end of the list.
121126
122127
Args:
@@ -130,7 +135,7 @@ def extend(self, modules: Sequence[ModuleType]) -> 'InfiniCoreModuleList':
130135
f"InfiniCoreModuleList.extend should be called with an "
131136
f"iterable, but got {type(modules).__name__}"
132137
)
133-
138+
134139
offset = len(self)
135140
for i, module in enumerate(modules):
136141
self.add_module(str(offset + i), module)
@@ -141,7 +146,7 @@ def insert(self, index: int, module: ModuleType) -> None:
141146
142147
Args:
143148
index (int): index to insert.
144-
module (nn.Module or InfiniCoreModule): module to insert
149+
module ( InfiniCoreModule): module to insert
145150
"""
146151
for i in range(len(self._modules), index, -1):
147152
self._modules[str(i)] = self._modules[str(i - 1)]
@@ -166,11 +171,11 @@ def __repr__(self) -> str:
166171
"""Return a string representation of the ModuleList."""
167172
if len(self) == 0:
168173
return self.__class__.__name__ + "()"
169-
174+
170175
lines = []
171176
for i, module in enumerate(self):
172177
lines.append(f"({i}): {repr(module)}")
173-
178+
174179
main_str = self.__class__.__name__ + "(\n "
175180
main_str += "\n ".join(lines) + "\n)"
176181
return main_str
@@ -181,3 +186,6 @@ def __dir__(self) -> List[str]:
181186
# Filter out numeric keys to avoid cluttering dir() output
182187
keys = [key for key in keys if not key.isdigit()]
183188
return keys
189+
190+
191+
ModuleList = InfiniCoreModuleList

0 commit comments

Comments
 (0)