Skip to content

Commit 879f212

Browse files
author
pengcheng888
committed
调整文件格式和parameter.py路径
1 parent cf2811f commit 879f212

File tree

5 files changed

+27
-27
lines changed

5 files changed

+27
-27
lines changed

python/infinicore/nn/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from infinicore.nn import (
2-
functional as functional,
3-
)
4-
from infinicore.nn import (
5-
modules as modules,
6-
)
7-
from infinicore.nn.functional import * # noqa: F403
1+
from infinicore.nn import functional
82
from infinicore.nn.modules import * # noqa: F403
3+
from infinicore.nn.parameter import Parameter
4+
5+
__all__ = ["functional", "Parameter"]
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .container import ModuleList
22
from .module import Module
3-
from .parameter import Parameter
43

5-
__all__ = ["ModuleList", "Module", "Parameter"]
4+
__all__ = ["ModuleList", "Module"]

python/infinicore/nn/modules/container.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
from .module import Module
1313

14+
__all__ = ["ModuleList"]
15+
16+
1417
# Define type variable for module compatibility (supports InfiniCoreModule)
1518
ModuleType = TypeVar("ModuleType", bound=Union["Module"])
1619

python/infinicore/nn/modules/module.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232

3333
import infinicore
3434

35-
from .parameter import Parameter
35+
from ...tensor import Tensor
36+
from ..parameter import Parameter
37+
38+
__all__ = ["Module"]
3639

3740
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
3841
T = TypeVar("T", bound="InfiniCoreModule")
@@ -57,7 +60,7 @@ class InfiniCoreModule:
5760

5861
_version: int = 1
5962
_parameters: Dict[str, Optional[Parameter]]
60-
_buffers: Dict[str, Optional[infinicore.Tensor]]
63+
_buffers: Dict[str, Optional[Tensor]]
6164
_non_persistent_buffers_set: Set[str]
6265
_modules: Dict[str, Optional["InfiniCoreModule"]]
6366

@@ -87,9 +90,7 @@ def __getattr__(self, name: str) -> Any:
8790
f"'{type(self).__name__}' object has no attribute '{name}'"
8891
)
8992

90-
def __setattr__(
91-
self, name: str, value: Union[infinicore.Tensor, "InfiniCoreModule"]
92-
) -> None:
93+
def __setattr__(self, name: str, value: Union[Tensor, "InfiniCoreModule"]) -> None:
9394
def remove_from(*dicts_or_sets) -> None:
9495
for d in dicts_or_sets:
9596
if name in d:
@@ -113,7 +114,7 @@ def remove_from(*dicts_or_sets) -> None:
113114
)
114115
self.register_parameter(name, value)
115116
elif name in params: # value will overwrite the name of params.
116-
if not isinstance(value, infinicore.Tensor):
117+
if not isinstance(value, Tensor):
117118
raise TypeError(
118119
f"cannot assign 'value' as parameter '{name}' (infinicore.nn.Parameter, Parameter or None expected)"
119120
)
@@ -141,7 +142,7 @@ def remove_from(*dicts_or_sets) -> None:
141142
else:
142143
buffers = self.__dict__.get("_buffers")
143144
if buffers is not None and name in buffers:
144-
if value is not None and not isinstance(value, infinicore.Tensor):
145+
if value is not None and not isinstance(value, Tensor):
145146
raise TypeError(
146147
f"cannot assign 'value' as buffer '{name}' "
147148
"(torch.Tensor or None expected)"
@@ -154,7 +155,7 @@ def __call__(self, *input, **kwargs):
154155
return self.forward(*input, **kwargs)
155156

156157
def register_buffer(
157-
self, name: str, tensor: Optional[infinicore.Tensor], persistent: bool = True
158+
self, name: str, tensor: Optional[Tensor], persistent: bool = True
158159
) -> None:
159160
r"""Adds a buffer to the module.
160161
@@ -187,7 +188,7 @@ def register_buffer(
187188
raise KeyError('buffer name can\'t be empty string ""')
188189
elif hasattr(self, name) and name not in self._buffers:
189190
raise KeyError("attribute '{}' already exists".format(name))
190-
elif tensor is not None and not isinstance(tensor, infinicore.Tensor):
191+
elif tensor is not None and not isinstance(tensor, Tensor):
191192
raise TypeError(
192193
"cannot assign '{}' object to buffer '{}' "
193194
"(torch Tensor or None required)".format("tensor", name)
@@ -256,7 +257,7 @@ def register_parameter(self, name: str, param: Parameter) -> None:
256257
if param is None:
257258
self._parameters[name] = None # 竟然可以是None
258259
else:
259-
if not isinstance(param, (Parameter, infinicore.Tensor)):
260+
if not isinstance(param, (Parameter, Tensor)):
260261
raise TypeError(
261262
f"cannot assign 'param' object to parameter '{name}' "
262263
"(infinicore.nn.Parameter, Parameter or None required)"
@@ -477,9 +478,9 @@ def _load_from_state_dict(
477478
input_param = state_dict[key]
478479

479480
# input_param must be of type infinicore.Tensor
480-
if not isinstance(input_param, infinicore.Tensor):
481+
if not isinstance(input_param, Tensor):
481482
raise TypeError(
482-
f"While copying the parameter named {key}, expected infinicore.Tensor from checkpoint but received {type(input_param)}"
483+
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
483484
)
484485

485486
if (
@@ -575,7 +576,7 @@ def load(module, local_state_dict, prefix=""):
575576
for k, v in local_state_dict.items()
576577
if k.startswith(child_prefix)
577578
}
578-
load(child, child_state_dict, child_prefix) # noqa: F821
579+
load(child, child_state_dict, child_prefix)
579580

580581
load(self, state_dict)
581582
del load
@@ -654,7 +655,7 @@ def named_parameters(
654655
for elem in gen:
655656
yield elem
656657

657-
def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
658+
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
658659
r"""Returns an iterator over module buffers.
659660
660661
Args:
@@ -677,7 +678,7 @@ def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
677678

678679
def named_buffers(
679680
self, prefix: str = "", recurse: bool = True
680-
) -> Iterator[Tuple[str, infinicore.Tensor]]:
681+
) -> Iterator[Tuple[str, Tensor]]:
681682
r"""Returns an iterator over module buffers, yielding both the
682683
name of the buffer as well as the buffer itself.
683684

python/infinicore/nn/modules/parameter.py renamed to python/infinicore/nn/parameter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# The use of this file is governed by the BSD 3-Clause License.
1414

1515

16-
import infinicore
16+
from ..tensor import Tensor
1717

1818

19-
class InfiniCoreParameter(infinicore.Tensor):
19+
class InfiniCoreParameter(Tensor):
2020
r"""A kind of Tensor that is to be considered a module parameter."""
2121

2222
def __init__(self, data=None):
23-
if not isinstance(data, infinicore.Tensor):
23+
if not isinstance(data, Tensor):
2424
raise ValueError("The `data` variable must be of type `infinicore.Tensor`.")
2525
super().__init__(data._underlying)
2626

0 commit comments

Comments
 (0)