Skip to content

Commit cb3f853

Browse files
authored
Merge pull request #2 from pengcheng888/feature/add_nn_interface
Feature/add nn interface
2 parents cf2811f + 058fd41 commit cb3f853

File tree

5 files changed

+25
-39
lines changed

5 files changed

+25
-39
lines changed

python/infinicore/nn/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from infinicore.nn import (
2-
functional as functional,
3-
)
4-
from infinicore.nn import (
5-
modules as modules,
6-
)
7-
from infinicore.nn.functional import * # noqa: F403
1+
from infinicore.nn import functional
82
from infinicore.nn.modules import * # noqa: F403
3+
from infinicore.nn.parameter import InfiniCoreParameter as Parameter
4+
5+
__all__ = ["functional", "Parameter"]
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from .container import ModuleList
2-
from .module import Module
3-
from .parameter import Parameter
1+
from .container import InfiniCoreModuleList as ModuleList
2+
from .module import InfiniCoreModule as Module
43

5-
__all__ = ["ModuleList", "Module", "Parameter"]
4+
__all__ = ["ModuleList", "Module"]

python/infinicore/nn/modules/container.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from itertools import chain
1010
from typing import Iterator, List, Optional, Sequence, TypeVar, Union
1111

12-
from .module import Module
12+
from .module import InfiniCoreModule as Module
1313

1414
# Define type variable for module compatibility (supports InfiniCoreModule)
1515
ModuleType = TypeVar("ModuleType", bound=Union["Module"])
@@ -186,6 +186,3 @@ def __dir__(self) -> List[str]:
186186
# Filter out numeric keys to avoid cluttering dir() output
187187
keys = [key for key in keys if not key.isdigit()]
188188
return keys
189-
190-
191-
ModuleList = InfiniCoreModuleList

python/infinicore/nn/modules/module.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232

3333
import infinicore
3434

35-
from .parameter import Parameter
35+
from ...tensor import Tensor
36+
from ..parameter import InfiniCoreParameter as Parameter
3637

3738
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
3839
T = TypeVar("T", bound="InfiniCoreModule")
@@ -57,7 +58,7 @@ class InfiniCoreModule:
5758

5859
_version: int = 1
5960
_parameters: Dict[str, Optional[Parameter]]
60-
_buffers: Dict[str, Optional[infinicore.Tensor]]
61+
_buffers: Dict[str, Optional[Tensor]]
6162
_non_persistent_buffers_set: Set[str]
6263
_modules: Dict[str, Optional["InfiniCoreModule"]]
6364

@@ -87,9 +88,7 @@ def __getattr__(self, name: str) -> Any:
8788
f"'{type(self).__name__}' object has no attribute '{name}'"
8889
)
8990

90-
def __setattr__(
91-
self, name: str, value: Union[infinicore.Tensor, "InfiniCoreModule"]
92-
) -> None:
91+
def __setattr__(self, name: str, value: Union[Tensor, "InfiniCoreModule"]) -> None:
9392
def remove_from(*dicts_or_sets) -> None:
9493
for d in dicts_or_sets:
9594
if name in d:
@@ -113,7 +112,7 @@ def remove_from(*dicts_or_sets) -> None:
113112
)
114113
self.register_parameter(name, value)
115114
elif name in params: # value will overwrite the name of params.
116-
if not isinstance(value, infinicore.Tensor):
115+
if not isinstance(value, Tensor):
117116
raise TypeError(
118117
f"cannot assign 'value' as parameter '{name}' (infinicore.nn.Parameter, Parameter or None expected)"
119118
)
@@ -141,7 +140,7 @@ def remove_from(*dicts_or_sets) -> None:
141140
else:
142141
buffers = self.__dict__.get("_buffers")
143142
if buffers is not None and name in buffers:
144-
if value is not None and not isinstance(value, infinicore.Tensor):
143+
if value is not None and not isinstance(value, Tensor):
145144
raise TypeError(
146145
f"cannot assign 'value' as buffer '{name}' "
147146
"(torch.Tensor or None expected)"
@@ -154,7 +153,7 @@ def __call__(self, *input, **kwargs):
154153
return self.forward(*input, **kwargs)
155154

156155
def register_buffer(
157-
self, name: str, tensor: Optional[infinicore.Tensor], persistent: bool = True
156+
self, name: str, tensor: Optional[Tensor], persistent: bool = True
158157
) -> None:
159158
r"""Adds a buffer to the module.
160159
@@ -187,7 +186,7 @@ def register_buffer(
187186
raise KeyError('buffer name can\'t be empty string ""')
188187
elif hasattr(self, name) and name not in self._buffers:
189188
raise KeyError("attribute '{}' already exists".format(name))
190-
elif tensor is not None and not isinstance(tensor, infinicore.Tensor):
189+
elif tensor is not None and not isinstance(tensor, Tensor):
191190
raise TypeError(
192191
"cannot assign '{}' object to buffer '{}' "
193192
"(torch Tensor or None required)".format("tensor", name)
@@ -256,7 +255,7 @@ def register_parameter(self, name: str, param: Parameter) -> None:
256255
if param is None:
257256
self._parameters[name] = None # 竟然可以是None
258257
else:
259-
if not isinstance(param, (Parameter, infinicore.Tensor)):
258+
if not isinstance(param, (Parameter, Tensor)):
260259
raise TypeError(
261260
f"cannot assign 'param' object to parameter '{name}' "
262261
"(infinicore.nn.Parameter, Parameter or None required)"
@@ -477,9 +476,9 @@ def _load_from_state_dict(
477476
input_param = state_dict[key]
478477

479478
# input_param must be of type infinicore.Tensor
480-
if not isinstance(input_param, infinicore.Tensor):
479+
if not isinstance(input_param, Tensor):
481480
raise TypeError(
482-
f"While copying the parameter named {key}, expected infinicore.Tensor from checkpoint but received {type(input_param)}"
481+
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
483482
)
484483

485484
if (
@@ -575,7 +574,7 @@ def load(module, local_state_dict, prefix=""):
575574
for k, v in local_state_dict.items()
576575
if k.startswith(child_prefix)
577576
}
578-
load(child, child_state_dict, child_prefix) # noqa: F821
577+
load(child, child_state_dict, child_prefix)
579578

580579
load(self, state_dict)
581580
del load
@@ -654,7 +653,7 @@ def named_parameters(
654653
for elem in gen:
655654
yield elem
656655

657-
def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
656+
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
658657
r"""Returns an iterator over module buffers.
659658
660659
Args:
@@ -677,7 +676,7 @@ def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
677676

678677
def named_buffers(
679678
self, prefix: str = "", recurse: bool = True
680-
) -> Iterator[Tuple[str, infinicore.Tensor]]:
679+
) -> Iterator[Tuple[str, Tensor]]:
681680
r"""Returns an iterator over module buffers, yielding both the
682681
name of the buffer as well as the buffer itself.
683682
@@ -856,6 +855,3 @@ def _apply(self, fn, recurse=True):
856855

857856
def to(self, *args, **kwargs):
858857
raise KeyError("not support")
859-
860-
861-
Module = InfiniCoreModule

python/infinicore/nn/modules/parameter.py renamed to python/infinicore/nn/parameter.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# The use of this file is governed by the BSD 3-Clause License.
1414

1515

16-
import infinicore
16+
from ..tensor import Tensor
1717

1818

19-
class InfiniCoreParameter(infinicore.Tensor):
19+
class InfiniCoreParameter(Tensor):
2020
r"""A kind of Tensor that is to be considered a module parameter."""
2121

2222
def __init__(self, data=None):
23-
if not isinstance(data, infinicore.Tensor):
23+
if not isinstance(data, Tensor):
2424
raise ValueError("The `data` variable must be of type `infinicore.Tensor`.")
2525
super().__init__(data._underlying)
2626

@@ -32,6 +32,3 @@ def __deepcopy__(self, memo):
3232

3333
def __reduce_ex__(self, proto):
3434
raise ValueError("not supported!")
35-
36-
37-
Parameter = InfiniCoreParameter

0 commit comments

Comments
 (0)