3232
3333import infinicore
3434
35- from .parameter import Parameter
35+ from ...tensor import Tensor
36+ from ..parameter import InfiniCoreParameter as Parameter
3637
3738_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
3839T = TypeVar ("T" , bound = "InfiniCoreModule" )
@@ -57,7 +58,7 @@ class InfiniCoreModule:
5758
5859 _version : int = 1
5960 _parameters : Dict [str , Optional [Parameter ]]
60- _buffers : Dict [str , Optional [infinicore . Tensor ]]
61+ _buffers : Dict [str , Optional [Tensor ]]
6162 _non_persistent_buffers_set : Set [str ]
6263 _modules : Dict [str , Optional ["InfiniCoreModule" ]]
6364
@@ -87,9 +88,7 @@ def __getattr__(self, name: str) -> Any:
8788 f"'{ type (self ).__name__ } ' object has no attribute '{ name } '"
8889 )
8990
90- def __setattr__ (
91- self , name : str , value : Union [infinicore .Tensor , "InfiniCoreModule" ]
92- ) -> None :
91+ def __setattr__ (self , name : str , value : Union [Tensor , "InfiniCoreModule" ]) -> None :
9392 def remove_from (* dicts_or_sets ) -> None :
9493 for d in dicts_or_sets :
9594 if name in d :
@@ -113,7 +112,7 @@ def remove_from(*dicts_or_sets) -> None:
113112 )
114113 self .register_parameter (name , value )
115114 elif name in params : # value will overwrite the name of params.
116- if not isinstance (value , infinicore . Tensor ):
115+ if not isinstance (value , Tensor ):
117116 raise TypeError (
118117 f"cannot assign 'value' as parameter '{ name } ' (infinicore.nn.Parameter, Parameter or None expected)"
119118 )
@@ -141,7 +140,7 @@ def remove_from(*dicts_or_sets) -> None:
141140 else :
142141 buffers = self .__dict__ .get ("_buffers" )
143142 if buffers is not None and name in buffers :
144- if value is not None and not isinstance (value , infinicore . Tensor ):
143+ if value is not None and not isinstance (value , Tensor ):
145144 raise TypeError (
146145 f"cannot assign 'value' as buffer '{ name } ' "
147146 "(torch.Tensor or None expected)"
@@ -154,7 +153,7 @@ def __call__(self, *input, **kwargs):
154153 return self .forward (* input , ** kwargs )
155154
156155 def register_buffer (
157- self , name : str , tensor : Optional [infinicore . Tensor ], persistent : bool = True
156+ self , name : str , tensor : Optional [Tensor ], persistent : bool = True
158157 ) -> None :
159158 r"""Adds a buffer to the module.
160159
@@ -187,7 +186,7 @@ def register_buffer(
187186 raise KeyError ('buffer name can\' t be empty string ""' )
188187 elif hasattr (self , name ) and name not in self ._buffers :
189188 raise KeyError ("attribute '{}' already exists" .format (name ))
190- elif tensor is not None and not isinstance (tensor , infinicore . Tensor ):
189+ elif tensor is not None and not isinstance (tensor , Tensor ):
191190 raise TypeError (
192191 "cannot assign '{}' object to buffer '{}' "
193192 "(torch Tensor or None required)" .format ("tensor" , name )
@@ -256,7 +255,7 @@ def register_parameter(self, name: str, param: Parameter) -> None:
256255 if param is None :
257256 self ._parameters [name ] = None # 竟然可以是None
258257 else :
259- if not isinstance (param , (Parameter , infinicore . Tensor )):
258+ if not isinstance (param , (Parameter , Tensor )):
260259 raise TypeError (
261260 f"cannot assign 'param' object to parameter '{ name } ' "
262261 "(infinicore.nn.Parameter, Parameter or None required)"
@@ -477,9 +476,9 @@ def _load_from_state_dict(
477476 input_param = state_dict [key ]
478477
479478 # input_param must be of type infinicore.Tensor
480- if not isinstance (input_param , infinicore . Tensor ):
479+ if not isinstance (input_param , Tensor ):
481480 raise TypeError (
482- f"While copying the parameter named { key } , expected infinicore. Tensor from checkpoint but received { type (input_param )} "
481+ f"While copying the parameter named { key } , expected Tensor from checkpoint but received { type (input_param )} "
483482 )
484483
485484 if (
@@ -575,7 +574,7 @@ def load(module, local_state_dict, prefix=""):
575574 for k , v in local_state_dict .items ()
576575 if k .startswith (child_prefix )
577576 }
578- load (child , child_state_dict , child_prefix ) # noqa: F821
577+ load (child , child_state_dict , child_prefix )
579578
580579 load (self , state_dict )
581580 del load
@@ -654,7 +653,7 @@ def named_parameters(
654653 for elem in gen :
655654 yield elem
656655
657- def buffers (self , recurse : bool = True ) -> Iterator [infinicore . Tensor ]:
656+ def buffers (self , recurse : bool = True ) -> Iterator [Tensor ]:
658657 r"""Returns an iterator over module buffers.
659658
660659 Args:
@@ -677,7 +676,7 @@ def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
677676
678677 def named_buffers (
679678 self , prefix : str = "" , recurse : bool = True
680- ) -> Iterator [Tuple [str , infinicore . Tensor ]]:
679+ ) -> Iterator [Tuple [str , Tensor ]]:
681680 r"""Returns an iterator over module buffers, yielding both the
682681 name of the buffer as well as the buffer itself.
683682
@@ -856,6 +855,3 @@ def _apply(self, fn, recurse=True):
856855
857856 def to (self , * args , ** kwargs ):
858857 raise KeyError ("not support" )
859-
860-
861- Module = InfiniCoreModule
0 commit comments