@@ -776,7 +776,9 @@ def pack_group(group):
 
         return {"state": packed_state, "param_groups": param_groups}
 
-    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
+    def state_dict(
+        self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
+    ) -> Dict:
         """Return a state_dict same with DDP
 
         Returns:
@@ -785,23 +787,29 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            working_param = self.master_to_working_param[id(param)]
+            pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                zero_state[param] = copy.deepcopy(state)
+            else:
+                zero_state[param] = {}
+
             if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
-            zero_state[param] = copy.deepcopy(state)
+
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    working_param = self.master_to_working_param[id(param)]
-                    pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
 
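Note: the hunk above still all-gathers every parameter's state on all ranks (the collective must stay symmetric), but with `only_on_master=True` only rank 0 of each process group keeps a materialized copy; other ranks store an empty dict. A minimal, self-contained sketch of that gating pattern, using a hypothetical helper (`pack_rank_state` is not part of the ColossalAI API):

import copy

def pack_rank_state(state: dict, rank: int, only_on_master: bool = False) -> dict:
    # Toy stand-in for the per-parameter loop above: `state` maps a parameter
    # key to that parameter's optimizer state dict.
    packed = {}
    for key, param_state in state.items():
        if not only_on_master or rank == 0:
            packed[key] = copy.deepcopy(param_state)  # group-master rank keeps the full copy
        else:
            packed[key] = {}  # other ranks keep an empty placeholder
    return packed

# rank 0 returns the full state, rank 1 returns empty placeholders
print(pack_rank_state({"w": {"step": 3, "exp_avg": [0.1, 0.2]}}, rank=0, only_on_master=True))
print(pack_rank_state({"w": {"step": 3, "exp_avg": [0.1, 0.2]}}, rank=1, only_on_master=True))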
@@ -837,7 +845,10 @@ def load_state_dict(self, state_dict: Dict):
         self.optim.load_state_dict(zero_state_dict)
 
     def state_dict_shard(
-        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+        self,
+        max_shard_size: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        only_on_master: bool = False,
     ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
            Only include the 'state' in state_dict.
@@ -862,25 +873,31 @@ def state_dict_shard(
                 cnt += 1
         for param_idx, states in local_states.items():
             current_block_size = 0
-            current_block = copy.deepcopy(states)
             if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                current_block = copy.deepcopy(states)
+            else:
+                current_block = {}
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
                     current_block_size += calculate_tensor_size(state_tensor)
 
             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
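Note: `state_dict_shard` accumulates per-parameter blocks and starts a new shard once adding the next block would push the current one past `max_shard_size`. A minimal, self-contained sketch of that policy (the helper name `shard_blocks` and the element-count size metric are illustrative assumptions; the real method measures tensor sizes via `calculate_tensor_size`):

from typing import Dict, Iterator, List, Tuple

def shard_blocks(
    blocks: Dict[int, List[float]], max_shard_size: int
) -> Iterator[Tuple[Dict[int, List[float]], int]]:
    # Accumulate blocks into the current shard; flush it before it would overflow.
    ret_block: Dict[int, List[float]] = {}
    ret_block_size = 0
    for idx, block in blocks.items():
        block_size = len(block)  # stand-in for calculate_tensor_size(...)
        if ret_block_size + block_size > max_shard_size and len(ret_block) > 0:
            yield ret_block, ret_block_size
            ret_block, ret_block_size = {}, 0
        ret_block[idx] = block
        ret_block_size += block_size
    if len(ret_block) > 0:
        yield ret_block, ret_block_size  # flush the final partial shard

# Blocks of sizes 2, 2 and 4 with a limit of 5 yield two shards: {0, 1}, then {2}.
for shard, size in shard_blocks({0: [0.0] * 2, 1: [0.0] * 2, 2: [0.0] * 4}, max_shard_size=5):
    print(sorted(shard), size)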