@@ -79,13 +79,13 @@ def fn_(obj: Any) -> None:
7979 obj .detach_ ().requires_grad_ (requires_grad )
8080
8181 if isinstance (target , ModuleState ):
82- true_target = cast (TensorTree , (target .params , target .buffers ))
82+ true_target = cast (' TensorTree' , (target .params , target .buffers ))
8383 elif isinstance (target , nn .Module ):
84- true_target = cast (TensorTree , tuple (target .parameters ()))
84+ true_target = cast (' TensorTree' , tuple (target .parameters ()))
8585 elif isinstance (target , MetaOptimizer ):
86- true_target = cast (TensorTree , target .state_dict ())
86+ true_target = cast (' TensorTree' , target .state_dict ())
8787 else :
88- true_target = cast (TensorTree , target ) # tree of tensors
88+ true_target = cast (' TensorTree' , target ) # tree of tensors
8989
9090 pytree .tree_map_ (fn_ , true_target )
9191
@@ -325,7 +325,7 @@ def recover_state_dict(
325325 from torchopt .optim .meta .base import MetaOptimizer
326326
327327 if isinstance (target , nn .Module ):
328- params , buffers , * _ = state = cast (ModuleState , state )
328+ params , buffers , * _ = state = cast (' ModuleState' , state )
329329 params_containers , buffers_containers = extract_module_containers (target , with_buffers = True )
330330
331331 if state .detach_buffers :
@@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
343343 ):
344344 tgt .update (src )
345345 elif isinstance (target , MetaOptimizer ):
346- state = cast (Sequence [OptState ], state )
346+ state = cast (' Sequence[OptState]' , state )
347347 target .load_state_dict (state )
348348 else :
349349 raise TypeError (f'Unexpected class of { target } ' )
@@ -422,9 +422,9 @@ def module_clone( # noqa: C901
422422
423423 if isinstance (target , (nn .Module , MetaOptimizer )):
424424 if isinstance (target , nn .Module ):
425- containers = cast (TensorTree , extract_module_containers (target , with_buffers = True ))
425+ containers = cast (' TensorTree' , extract_module_containers (target , with_buffers = True ))
426426 else :
427- containers = cast (TensorTree , target .state_dict ())
427+ containers = cast (' TensorTree' , target .state_dict ())
428428 tensors = pytree .tree_leaves (containers )
429429 memo = {id (t ): t for t in tensors }
430430 cloned = copy .deepcopy (target , memo = memo )
@@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
476476 else :
477477 replicate = clone_detach_
478478
479- return pytree .tree_map (replicate , cast (TensorTree , target ))
479+ return pytree .tree_map (replicate , cast (' TensorTree' , target ))
480480
481481
482482@overload
0 commit comments