
Commit 4f473b5

zhuyuegongchensu authored and committed
Support InfiniCoreParameter in InfiniCoreModule.
1 parent f9492c1 commit 4f473b5

python/infinicore/nn/modules/module.py

Lines changed: 26 additions & 15 deletions
@@ -16,12 +16,16 @@
 from collections import OrderedDict, namedtuple
 import itertools
 import warnings
+from typing import TYPE_CHECKING
 
 import torch
 
 from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
+if TYPE_CHECKING:
+    from .parameter import InfiniCoreParameter as Parameter
+
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 T = TypeVar('T', bound='InfiniCoreModule')
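The new `TYPE_CHECKING` guard makes the `Parameter` alias visible to static type checkers without importing `.parameter` at runtime, which avoids a circular import between `module.py` and `parameter.py`. A minimal standalone sketch of the pattern (the `first_param` helper is illustrative, not part of this commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static analyzers such as mypy; never executed at runtime.
    from .parameter import InfiniCoreParameter as Parameter

def first_param(params: "dict[str, Parameter]") -> "Parameter":
    # String annotations defer evaluation, so no runtime import is needed.
    return next(iter(params.values()))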
@@ -46,7 +50,7 @@ class InfiniCoreModule:
     _version: int = 1
 
     training: bool
-    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'InfiniCoreParameter']]]
+    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'Parameter']]]
     _buffers: Dict[str, Optional[torch.Tensor]]
     _non_persistent_buffers_set: Set[str]
     _modules: Dict[str, Optional['InfiniCoreModule']]
@@ -84,9 +88,9 @@ def remove_from(*dicts_or_sets) -> None:
                         d.discard(name)
 
         params = self.__dict__.get("_parameters")
-        # Support both torch.nn.Parameter and InfiniCoreParameter
-        from .parameter import InfiniCoreParameter
-        if isinstance(value, (torch.nn.Parameter, InfiniCoreParameter)):
+        # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+        from .parameter import InfiniCoreParameter as Parameter
+        if isinstance(value, (torch.nn.Parameter, Parameter)):
             if params is None:
                 raise AttributeError(
                     "cannot assign parameters before Module.__init__() call"
@@ -102,7 +106,7 @@ def remove_from(*dicts_or_sets) -> None:
             if value is not None:
                 raise TypeError(
                     f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None expected)"
+                    "(torch.nn.Parameter, Parameter or None expected)"
                 )
             self.register_parameter(name, value)
         else:
@@ -210,7 +214,7 @@ def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
 
         self._modules[name] = module
 
-    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+    def register_parameter(self, name: str, param: Optional[Union[torch.nn.Parameter, 'Parameter']]) -> None:
         r"""Add a parameter to the module.
 
         The parameter can be accessed as an attribute using given name.
@@ -242,12 +246,12 @@ def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) ->
         if param is None:
             self._parameters[name] = None
         else:
-            # Support both torch.nn.Parameter and InfiniCoreParameter
-            from .parameter import InfiniCoreParameter
-            if not isinstance(param, (torch.nn.Parameter, InfiniCoreParameter)):
+            # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+            from .parameter import InfiniCoreParameter as Parameter
+            if not isinstance(param, (torch.nn.Parameter, Parameter)):
                 raise TypeError(
                     f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None required)"
+                    "(torch.nn.Parameter, Parameter or None required)"
                 )
             self._parameters[name] = param
 
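With both the `__setattr__` path and `register_parameter` accepting either class, a module can hold `torch.nn.Parameter` and `InfiniCoreParameter` objects side by side. A hedged usage sketch; the import paths and the `InfiniCoreParameter(tensor)` constructor are assumptions inferred from this diff, not verified API:

import torch
from infinicore.nn.modules.module import InfiniCoreModule        # assumed import path
from infinicore.nn.modules.parameter import InfiniCoreParameter  # assumed import path

m = InfiniCoreModule()
m.register_parameter("weight", torch.nn.Parameter(torch.randn(4, 4)))
m.register_parameter("scale", InfiniCoreParameter(torch.ones(4)))  # assumed constructor

# Plain tensors are still rejected, mirroring torch.nn.Module semantics.
try:
    m.register_parameter("bad", torch.ones(2))
except TypeError as err:
    print(err)  # ... (torch.nn.Parameter, Parameter or None required)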

@@ -557,7 +561,7 @@ def load(module, local_state_dict, prefix=''):
                 self.__class__.__name__, "\n\t".join(error_msgs)))
         return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+    def parameters(self, recurse: bool = True) -> Iterator[Union[torch.nn.Parameter, 'Parameter']]:
         r"""Returns an iterator over module parameters.
 
         Args:
@@ -578,7 +582,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
         for name, param in self.named_parameters(recurse=recurse):
             yield param
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Union[torch.nn.Parameter, 'Parameter']]]:
         r"""Returns an iterator over module parameters, yielding both the
         name of the parameter as well as the parameter itself.
 
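Because `parameters()` and `named_parameters()` now advertise the union type, downstream code that assumed `torch.nn.Parameter` only should dispatch on type where the distinction matters. Continuing the assumptions from the sketch above:

for name, p in m.named_parameters():
    # p may be either a torch.nn.Parameter or an InfiniCoreParameter.
    kind = "infinicore" if isinstance(p, InfiniCoreParameter) else "torch"
    print(f"{name}: {kind}")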
@@ -845,6 +849,9 @@ def compute_should_use_set_data(tensor, tensor_applied):
             return False
 
         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
+
+        # Import Parameter (InfiniCoreParameter) for type checking and creation
+        from .parameter import InfiniCoreParameter as Parameter
 
         for key, param in self._parameters.items():
             if param is None:
@@ -859,14 +866,18 @@ def compute_should_use_set_data(tensor, tensor_applied):
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
 
+            # Determine the Parameter class to use based on the original parameter type
+            is_infinicore_param = isinstance(param, Parameter)
+            ParamClass = Parameter if is_infinicore_param else torch.nn.Parameter
+
             param_grad = param.grad
             if p_should_use_swap_tensors:
                 try:
                     if param_grad is not None:
                         # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
                         # Decrement use count of the gradient by setting to None
                         param.grad = None
-                    param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
+                    param_applied = ParamClass(param_applied, requires_grad=param.requires_grad)
                     torch.utils.swap_tensors(param, param_applied)
                 except Exception as e:
                     if param_grad is not None:
@@ -877,9 +888,9 @@ def compute_should_use_set_data(tensor, tensor_applied):
                 param.data = param_applied
                 out_param = param
             else:
-                assert isinstance(param, torch.nn.Parameter)
+                assert isinstance(param, (torch.nn.Parameter, Parameter))
                 assert param.is_leaf
-                out_param = torch.nn.Parameter(param_applied, param.requires_grad)
+                out_param = ParamClass(param_applied, param.requires_grad)
             self._parameters[key] = out_param
 
             if param_grad is not None:
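The `ParamClass` dispatch is what keeps `_apply` conversions such as `.to()` or `.float()` from silently rebuilding an `InfiniCoreParameter` as a plain `torch.nn.Parameter`. The same technique in isolation, using a stand-in subclass so the sketch runs without infinicore installed:

import torch

class MyParameter(torch.nn.Parameter):
    """Stand-in for InfiniCoreParameter; any Parameter subclass works here."""

def rebuild(param, param_applied):
    # Choose the constructor from the original parameter's type so the
    # subclass identity survives the rebuild, as _apply does above.
    ParamClass = MyParameter if isinstance(param, MyParameter) else torch.nn.Parameter
    return ParamClass(param_applied, requires_grad=param.requires_grad)

p = MyParameter(torch.randn(3))
q = rebuild(p, p.data.double())
assert type(q) is MyParameter and q.dtype == torch.float64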
