@@ -780,19 +780,19 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
-            if pinned_state_dicts and param not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param][k].copy_(param_state)
                         zero_state[param][k] = pinned_state_dicts[param][k]
                     else:
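
This hunk fixes two problems: `if pinned_state_dicts:` treats an empty dict the same as `None`, so a caller that passes a fresh `{}` never gets pinned buffers allocated, and the pinned buffer used to be sized from `working_param` before that name was assigned in the loop, whereas it should match the gathered `param_state`. A minimal sketch of the truthiness pitfall, using a hypothetical `gather_state` helper rather than the ColossalAI API (page-locked allocation assumes a CUDA-enabled PyTorch build):

```python
import torch

def gather_state(state, pinned_state_dicts=None):
    # Hypothetical helper mirroring the pattern above (not the ColossalAI API).
    # `state` maps optimizer-state keys to tensors; `pinned_state_dicts` is
    # either None (no pinning) or a dict used as a cache of pinned CPU buffers.
    out = {}
    for k, v in state.items():
        # Buggy form: `if pinned_state_dicts and k not in pinned_state_dicts:`
        # An empty dict is falsy, so the buffer is never allocated on the first
        # call even though the caller asked for pinning.
        if pinned_state_dicts is not None and k not in pinned_state_dicts:
            # Page-locked allocation requires a CUDA-enabled build.
            pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
        if pinned_state_dicts is not None:
            pinned_state_dicts[k].copy_(v)  # reuse the cached pinned buffer
            out[k] = pinned_state_dicts[k]
        else:
            out[k] = v.detach().cpu()
    return out

pinned = {}  # caller requests pinned buffers with a fresh, still-empty cache
cpu_state = gather_state({"exp_avg": torch.randn(4)}, pinned)
assert cpu_state["exp_avg"].is_pinned()  # holds with `is not None`, fails with the truthiness check
```

With the `is not None` check, `None` still means "no pinning requested", while an empty dict means "pin, and cache the buffers here".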
@@ -858,7 +858,7 @@ def state_dict_shard(
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
-            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
@@ -869,9 +869,9 @@ def state_dict_shard(
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                         pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
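
In both `state_dict` and `state_dict_shard`, the pinned CPU buffer is allocated only after the all-gather, from the gathered and reshaped tensor rather than the flat local shard. A sketch of that ordering, with plain `torch.distributed` collectives standing in for ColossalAI's `get_nd_world_size` / `all_gather_into_flat_tensor_nd` helpers and illustrative names throughout:

```python
import torch
import torch.distributed as dist

def gather_full_state(v, working_param, pg, pinned=None, key="exp_avg"):
    # Illustrative only: `v` is the flat local shard of one optimizer state
    # tensor, `working_param` the unsharded working-precision parameter.
    world_size = dist.get_world_size(group=pg)
    gathered = torch.empty(v.numel() * world_size, device=v.device, dtype=v.dtype)
    dist.all_gather_into_tensor(gathered, v.flatten().contiguous(), group=pg)
    # Drop padding and restore the working parameter's shape.
    full = gathered[: working_param.numel()].reshape_as(working_param)
    if pinned is not None and key not in pinned:
        # Allocate the pinned buffer from the gathered tensor: it has the full
        # shape and the state's dtype, which the local shard `v` does not.
        pinned[key] = torch.empty_like(full, pin_memory=True, device="cpu")
    if pinned is not None:
        pinned[key].copy_(full)  # device-to-host copy into page-locked memory
        return pinned[key]
    return full.cpu()
```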