3232
3333import infinicore
3434
35- from .parameter import Parameter
35+ from ...tensor import Tensor
36+ from ..parameter import Parameter
37+
38+ __all__ = ["Module" ]
3639
3740_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
3841T = TypeVar ("T" , bound = "InfiniCoreModule" )
@@ -57,7 +60,7 @@ class InfiniCoreModule:
5760
5861 _version : int = 1
5962 _parameters : Dict [str , Optional [Parameter ]]
60- _buffers : Dict [str , Optional [infinicore . Tensor ]]
63+ _buffers : Dict [str , Optional [Tensor ]]
6164 _non_persistent_buffers_set : Set [str ]
6265 _modules : Dict [str , Optional ["InfiniCoreModule" ]]
6366
@@ -87,9 +90,7 @@ def __getattr__(self, name: str) -> Any:
8790 f"'{ type (self ).__name__ } ' object has no attribute '{ name } '"
8891 )
8992
90- def __setattr__ (
91- self , name : str , value : Union [infinicore .Tensor , "InfiniCoreModule" ]
92- ) -> None :
93+ def __setattr__ (self , name : str , value : Union [Tensor , "InfiniCoreModule" ]) -> None :
9394 def remove_from (* dicts_or_sets ) -> None :
9495 for d in dicts_or_sets :
9596 if name in d :
@@ -113,7 +114,7 @@ def remove_from(*dicts_or_sets) -> None:
113114 )
114115 self .register_parameter (name , value )
115116 elif name in params : # value will overwrite the name of params.
116- if not isinstance (value , infinicore . Tensor ):
117+ if not isinstance (value , Tensor ):
117118 raise TypeError (
118119 f"cannot assign 'value' as parameter '{ name } ' (infinicore.nn.Parameter, Parameter or None expected)"
119120 )
@@ -141,7 +142,7 @@ def remove_from(*dicts_or_sets) -> None:
141142 else :
142143 buffers = self .__dict__ .get ("_buffers" )
143144 if buffers is not None and name in buffers :
144- if value is not None and not isinstance (value , infinicore . Tensor ):
145+ if value is not None and not isinstance (value , Tensor ):
145146 raise TypeError (
146147 f"cannot assign 'value' as buffer '{ name } ' "
147148 "(torch.Tensor or None expected)"
@@ -154,7 +155,7 @@ def __call__(self, *input, **kwargs):
154155 return self .forward (* input , ** kwargs )
155156
156157 def register_buffer (
157- self , name : str , tensor : Optional [infinicore . Tensor ], persistent : bool = True
158+ self , name : str , tensor : Optional [Tensor ], persistent : bool = True
158159 ) -> None :
159160 r"""Adds a buffer to the module.
160161
@@ -187,7 +188,7 @@ def register_buffer(
187188 raise KeyError ('buffer name can\' t be empty string ""' )
188189 elif hasattr (self , name ) and name not in self ._buffers :
189190 raise KeyError ("attribute '{}' already exists" .format (name ))
190- elif tensor is not None and not isinstance (tensor , infinicore . Tensor ):
191+ elif tensor is not None and not isinstance (tensor , Tensor ):
191192 raise TypeError (
192193 "cannot assign '{}' object to buffer '{}' "
193194 "(torch Tensor or None required)" .format ("tensor" , name )
@@ -256,7 +257,7 @@ def register_parameter(self, name: str, param: Parameter) -> None:
256257 if param is None :
257258 self ._parameters [name ] = None # 竟然可以是None
258259 else :
259- if not isinstance (param , (Parameter , infinicore . Tensor )):
260+ if not isinstance (param , (Parameter , Tensor )):
260261 raise TypeError (
261262 f"cannot assign 'param' object to parameter '{ name } ' "
262263 "(infinicore.nn.Parameter, Parameter or None required)"
@@ -477,9 +478,9 @@ def _load_from_state_dict(
477478 input_param = state_dict [key ]
478479
479480 # input_param must be of type infinicore.Tensor
480- if not isinstance (input_param , infinicore . Tensor ):
481+ if not isinstance (input_param , Tensor ):
481482 raise TypeError (
482- f"While copying the parameter named { key } , expected infinicore. Tensor from checkpoint but received { type (input_param )} "
483+ f"While copying the parameter named { key } , expected Tensor from checkpoint but received { type (input_param )} "
483484 )
484485
485486 if (
@@ -575,7 +576,7 @@ def load(module, local_state_dict, prefix=""):
575576 for k , v in local_state_dict .items ()
576577 if k .startswith (child_prefix )
577578 }
578- load (child , child_state_dict , child_prefix ) # noqa: F821
579+ load (child , child_state_dict , child_prefix )
579580
580581 load (self , state_dict )
581582 del load
@@ -654,7 +655,7 @@ def named_parameters(
654655 for elem in gen :
655656 yield elem
656657
657- def buffers (self , recurse : bool = True ) -> Iterator [infinicore . Tensor ]:
658+ def buffers (self , recurse : bool = True ) -> Iterator [Tensor ]:
658659 r"""Returns an iterator over module buffers.
659660
660661 Args:
@@ -677,7 +678,7 @@ def buffers(self, recurse: bool = True) -> Iterator[infinicore.Tensor]:
677678
678679 def named_buffers (
679680 self , prefix : str = "" , recurse : bool = True
680- ) -> Iterator [Tuple [str , infinicore . Tensor ]]:
681+ ) -> Iterator [Tuple [str , Tensor ]]:
681682 r"""Returns an iterator over module buffers, yielding both the
682683 name of the buffer as well as the buffer itself.
683684
0 commit comments