@@ -801,13 +801,12 @@ def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, Non
801
801
802
802
state_dict_config = ShardedStateDictConfig (offload_to_cpu = True )
803
803
optim_state_dict_config = ShardedOptimStateDictConfig (offload_to_cpu = True )
804
- state_dict_type_context = FSDP .state_dict_type (
804
+ return FSDP .state_dict_type (
805
805
module = module ,
806
806
state_dict_type = StateDictType .SHARDED_STATE_DICT ,
807
807
state_dict_config = state_dict_config ,
808
808
optim_state_dict_config = optim_state_dict_config ,
809
809
)
810
- return state_dict_type_context # type: ignore[return-value]
811
810
812
811
813
812
def _get_full_state_dict_context (
@@ -819,15 +818,13 @@ def _get_full_state_dict_context(
819
818
820
819
state_dict_config = FullStateDictConfig (offload_to_cpu = True , rank0_only = rank0_only )
821
820
optim_state_dict_config = FullOptimStateDictConfig (offload_to_cpu = True , rank0_only = rank0_only )
822
- state_dict_type_context = FSDP .state_dict_type (
821
+ return FSDP .state_dict_type (
823
822
module = module ,
824
823
state_dict_type = StateDictType .FULL_STATE_DICT ,
825
824
state_dict_config = state_dict_config ,
826
825
optim_state_dict_config = optim_state_dict_config ,
827
826
)
828
827
829
- return state_dict_type_context # type: ignore[return-value]
830
-
831
828
832
829
def _is_sharded_checkpoint (path : Path ) -> bool :
833
830
"""A heuristic check to determine whether the path points to a directory with checkpoint shards."""
0 commit comments