@@ -801,13 +801,12 @@ def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, Non
801801
802802 state_dict_config = ShardedStateDictConfig (offload_to_cpu = True )
803803 optim_state_dict_config = ShardedOptimStateDictConfig (offload_to_cpu = True )
804- state_dict_type_context = FSDP .state_dict_type (
804+ return FSDP .state_dict_type (
805805 module = module ,
806806 state_dict_type = StateDictType .SHARDED_STATE_DICT ,
807807 state_dict_config = state_dict_config ,
808808 optim_state_dict_config = optim_state_dict_config ,
809809 )
810- return state_dict_type_context # type: ignore[return-value]
811810
812811
813812def _get_full_state_dict_context (
@@ -819,15 +818,13 @@ def _get_full_state_dict_context(
819818
820819 state_dict_config = FullStateDictConfig (offload_to_cpu = True , rank0_only = rank0_only )
821820 optim_state_dict_config = FullOptimStateDictConfig (offload_to_cpu = True , rank0_only = rank0_only )
822- state_dict_type_context = FSDP .state_dict_type (
821+ return FSDP .state_dict_type (
823822 module = module ,
824823 state_dict_type = StateDictType .FULL_STATE_DICT ,
825824 state_dict_config = state_dict_config ,
826825 optim_state_dict_config = optim_state_dict_config ,
827826 )
828827
829- return state_dict_type_context # type: ignore[return-value]
830-
831828
832829def _is_sharded_checkpoint (path : Path ) -> bool :
833830 """A heuristic check to determine whether the path points to a directory with checkpoint shards."""
0 commit comments