@@ -16,12 +16,16 @@
 from collections import OrderedDict, namedtuple
 import itertools
 import warnings
+from typing import TYPE_CHECKING
 
 import torch
 
 from typing import Union, Tuple, Any, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
+if TYPE_CHECKING:
+    from .parameter import InfiniCoreParameter as Parameter
+
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 T = TypeVar('T', bound='InfiniCoreModule')
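
Note: the TYPE_CHECKING guard added above gives static type checkers access to InfiniCoreParameter without importing .parameter at runtime, which (judging by the deferred runtime imports later in this diff) would risk a circular import. A minimal sketch of the pattern; everything except the .parameter import mirrored from this diff is illustrative:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen only by static type checkers (mypy, pyright); never executed
        # at runtime, so a circular import cannot occur.
        from .parameter import InfiniCoreParameter as Parameter

    class Demo:
        # The quoted 'Parameter' is a forward reference that checkers
        # resolve against the guarded import above.
        weight: Optional['Parameter'] = None
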
@@ -46,7 +50,7 @@ class InfiniCoreModule:
     _version: int = 1
 
     training: bool
-    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'InfiniCoreParameter']]]
+    _parameters: Dict[str, Optional[Union[torch.nn.Parameter, 'Parameter']]]
     _buffers: Dict[str, Optional[torch.Tensor]]
     _non_persistent_buffers_set: Set[str]
     _modules: Dict[str, Optional['InfiniCoreModule']]
@@ -84,9 +88,9 @@ def remove_from(*dicts_or_sets) -> None:
                         d.discard(name)
 
         params = self.__dict__.get("_parameters")
-        # Support both torch.nn.Parameter and InfiniCoreParameter
-        from .parameter import InfiniCoreParameter
-        if isinstance(value, (torch.nn.Parameter, InfiniCoreParameter)):
+        # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+        from .parameter import InfiniCoreParameter as Parameter
+        if isinstance(value, (torch.nn.Parameter, Parameter)):
             if params is None:
                 raise AttributeError(
                     "cannot assign parameters before Module.__init__() call"
@@ -102,7 +106,7 @@ def remove_from(*dicts_or_sets) -> None:
             if value is not None:
                 raise TypeError(
                     f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None expected)"
+                    "(torch.nn.Parameter, Parameter or None expected)"
                 )
             self.register_parameter(name, value)
         else:
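
Note: this branch of __setattr__ is what makes plain attribute assignment (self.weight = param) equivalent to calling register_parameter. A simplified sketch of that dispatch, not the actual implementation:

    import torch

    class TinyModule:
        """Stripped-down illustration of parameter routing in __setattr__."""

        def __init__(self):
            object.__setattr__(self, '_parameters', {})

        def __setattr__(self, name, value):
            # The real code also accepts InfiniCoreParameter here.
            if isinstance(value, torch.nn.Parameter):
                self._parameters[name] = value   # routed, not put in __dict__
            else:
                object.__setattr__(self, name, value)

    m = TinyModule()
    m.weight = torch.nn.Parameter(torch.randn(2, 2))
    assert 'weight' in m._parameters and 'weight' not in m.__dict__
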
@@ -210,7 +214,7 @@ def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
 
         self._modules[name] = module
 
-    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+    def register_parameter(self, name: str, param: Optional[Union[torch.nn.Parameter, 'Parameter']]) -> None:
         r"""Add a parameter to the module.
 
         The parameter can be accessed as an attribute using given name.
@@ -242,12 +246,12 @@ def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) ->
         if param is None:
             self._parameters[name] = None
         else:
-            # Support both torch.nn.Parameter and InfiniCoreParameter
-            from .parameter import InfiniCoreParameter
-            if not isinstance(param, (torch.nn.Parameter, InfiniCoreParameter)):
+            # Support both torch.nn.Parameter and Parameter (InfiniCoreParameter)
+            from .parameter import InfiniCoreParameter as Parameter
+            if not isinstance(param, (torch.nn.Parameter, Parameter)):
                 raise TypeError(
                     f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
-                    "(torch.nn.Parameter, InfiniCoreParameter or None required)"
+                    "(torch.nn.Parameter, Parameter or None required)"
                 )
             self._parameters[name] = param
 
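
A usage sketch for register_parameter as typed above. The runnable part uses torch.nn.Parameter with a plain torch module as a stand-in; an InfiniCoreParameter would pass the same isinstance check, though its constructor signature is not shown in this diff:

    import torch

    class Net(torch.nn.Module):   # stand-in; InfiniCoreModule behaves alike here
        def __init__(self):
            super().__init__()
            self.register_parameter('scale', torch.nn.Parameter(torch.ones(3)))
            self.register_parameter('bias', None)   # None is allowed: slot recorded

    net = Net()
    print([name for name, _ in net.named_parameters()])   # ['scale']; None slots are skipped

    # A raw tensor is rejected with the TypeError seen in the diff:
    try:
        net.register_parameter('x', torch.ones(3))
    except TypeError as e:
        print(e)
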
@@ -557,7 +561,7 @@ def load(module, local_state_dict, prefix=''):
                 self.__class__.__name__, "\n\t".join(error_msgs)))
         return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-    def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
+    def parameters(self, recurse: bool = True) -> Iterator[Union[torch.nn.Parameter, 'Parameter']]:
         r"""Returns an iterator over module parameters.
 
         Args:
@@ -578,7 +582,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]:
         for name, param in self.named_parameters(recurse=recurse):
             yield param
 
-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Union[torch.nn.Parameter, 'Parameter']]]:
         r"""Returns an iterator over module parameters, yielding both the
         name of the parameter as well as the parameter itself.
 
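
With the widened return annotations, code consuming these iterators should treat the yielded objects as "some Parameter", not specifically torch.nn.Parameter. For example (plain torch module used as a stand-in):

    import torch

    lin = torch.nn.Linear(4, 2)
    for name, p in lin.named_parameters():
        # Under InfiniCoreModule, p could instead be an InfiniCoreParameter;
        # rely only on the shared tensor surface (.shape, .requires_grad, ...).
        print(name, tuple(p.shape), p.requires_grad)
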
@@ -845,6 +849,9 @@ def compute_should_use_set_data(tensor, tensor_applied):
             return False
 
         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
+
+        # Import Parameter (InfiniCoreParameter) for type checking and creation
+        from .parameter import InfiniCoreParameter as Parameter
 
         for key, param in self._parameters.items():
             if param is None:
@@ -859,14 +866,18 @@ def compute_should_use_set_data(tensor, tensor_applied):
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
 
+            # Determine the Parameter class to use based on the original parameter type
+            is_infinicore_param = isinstance(param, Parameter)
+            ParamClass = Parameter if is_infinicore_param else torch.nn.Parameter
+
             param_grad = param.grad
             if p_should_use_swap_tensors:
                 try:
                     if param_grad is not None:
                         # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
                         # Decrement use count of the gradient by setting to None
                         param.grad = None
-                    param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
+                    param_applied = ParamClass(param_applied, requires_grad=param.requires_grad)
                     torch.utils.swap_tensors(param, param_applied)
                 except Exception as e:
                     if param_grad is not None:
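
This path relies on torch.utils.swap_tensors, which exchanges the full payloads of two tensors in place (available in recent PyTorch releases) and refuses to swap a tensor whose use count exceeds one, which is why the hunk drops param.grad first. A tiny demonstration, independent of this module:

    import torch

    t1 = torch.randn(3)
    t2 = torch.zeros(3)
    torch.utils.swap_tensors(t1, t2)
    assert t1.eq(0).all()   # t1 now holds t2's old payload, and vice versa
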
@@ -877,9 +888,9 @@ def compute_should_use_set_data(tensor, tensor_applied):
                 param.data = param_applied
                 out_param = param
             else:
-                assert isinstance(param, torch.nn.Parameter)
+                assert isinstance(param, (torch.nn.Parameter, Parameter))
                 assert param.is_leaf
-                out_param = torch.nn.Parameter(param_applied, param.requires_grad)
+                out_param = ParamClass(param_applied, param.requires_grad)
             self._parameters[key] = out_param
 
             if param_grad is not None:
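
The net effect of the ParamClass selection is that a conversion such as .to(dtype) now preserves the parameter's class instead of collapsing it to torch.nn.Parameter. A sketch with a hypothetical subclass standing in for InfiniCoreParameter; for simplicity the stand-in subclasses torch.nn.Parameter, which the real InfiniCoreParameter need not (that is why the diff's isinstance checks list both types):

    import torch

    class MyParam(torch.nn.Parameter):   # hypothetical stand-in for InfiniCoreParameter
        pass

    param = MyParam(torch.randn(2))
    param_applied = param.data.to(torch.float64)   # what fn(param) might produce

    # Rebuild with the original parameter's class, as _apply now does:
    ParamClass = MyParam if isinstance(param, MyParam) else torch.nn.Parameter
    out_param = ParamClass(param_applied, param.requires_grad)
    assert isinstance(out_param, MyParam) and out_param.dtype == torch.float64
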