@@ -786,30 +786,36 @@ def state_dict(
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            working_param = self.master_to_working_param[id(param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                zero_state[param] = copy.deepcopy(state)
-            else:
-                zero_state[param] = {}
-
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
-
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param][k].copy_(param_state)
-                            zero_state[param][k] = pinned_state_dicts[param][k]
-                        else:
-                            zero_state[param][k] = param_state.cpu()
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
+                working_param = self.master_to_working_param[id(param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    zero_state[param] = copy.deepcopy(state)
+                else:
+                    zero_state[param] = {}
+
+                if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                    pinned_state_dicts[param] = {}
+
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                                pinned_state_dicts[param][k] = torch.empty_like(
+                                    param_state, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param][k].copy_(param_state)
+                                zero_state[param][k] = pinned_state_dicts[param][k]
+                            else:
+                                zero_state[param][k] = param_state.cpu()


         states_dict = self._pack_state(zero_state)
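Note on the state_dict hunk above: iterating self.optim.param_groups instead of self.optim.state.items() pins down the order in which parameters are visited, which matters because all_gather_into_flat_tensor_nd is a collective and has to be issued in the same order on every rank; optim.state is also populated lazily, hence the new "param not in self.optim.state" guard. The sketch below only illustrates the generic torch.optim behaviour this relies on; the toy model and optimizer are hypothetical and not part of the patch.

    import torch

    # Hypothetical toy model/optimizer, only to illustrate torch.optim semantics.
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Before the first step, optim.state is empty, while param_groups already
    # lists every parameter in a fixed, deterministic order.
    assert len(optim.state) == 0
    for group in optim.param_groups:
        for param in group["params"]:
            if param not in optim.state:  # same guard as in the patch
                continue
            # ... this is where state for `param` would be gathered ...

    # After one step every parameter has state ("step", "exp_avg", "exp_avg_sq"),
    # and the param_groups order is what checkpoint code can rely on across ranks.
    model(torch.randn(2, 4)).sum().backward()
    optim.step()
    assert all(p in optim.state for g in optim.param_groups for p in g["params"])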
@@ -865,48 +871,52 @@ def state_dict_shard(
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]

-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
-                pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
-            working_param = self.master_to_working_param[id(master_param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                current_block = copy.deepcopy(states)
-            else:
-                current_block = {}
-
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(
-                                state_tensor, pin_memory=True, device="cpu"
-                            )
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                            current_block[k] = pinned_state_dicts[param_idx][k]
-                        else:
-                            current_block[k] = state_tensor.cpu()
-                    current_block_size += calculate_tensor_size(state_tensor)
-
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
-
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
+
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
+
+                current_block_size = 0
+                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                    pinned_state_dicts[param_idx] = {}
+                working_param = self.master_to_working_param[id(master_param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    current_block = copy.deepcopy(states)
+                else:
+                    current_block = {}
+
+                for k, v in states.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                                pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                    state_tensor, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                                current_block[k] = pinned_state_dicts[param_idx][k]
+                            else:
+                                current_block[k] = state_tensor.cpu()
+                        current_block_size += calculate_tensor_size(state_tensor)
+
+                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                    yield ret_block, ret_block_size
+                    ret_block = dict()
+                    ret_block_size = 0
+
+                ret_block[param_idx] = current_block
+                ret_block_size += current_block_size


         yield ret_block, ret_block_size
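Note on the state_dict_shard hunk: optim.state_dict()["state"] is keyed by integer indices assigned in param_groups order, so reversing the map (master2idx instead of idx2master) lets the shard loop walk parameters in that fixed, rank-consistent order while still fetching each parameter's local states by index, keeping the gather collectives aligned across ranks just as in state_dict. The sketch below only demonstrates the indexing convention with a hypothetical toy optimizer, not the patched method itself.

    import torch

    # Hypothetical toy optimizer, only to illustrate the state_dict() index scheme.
    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(1, 4)).sum().backward()
    optim.step()

    # Keys of state_dict()["state"] are the integers 0..N-1, assigned across
    # param_groups in order, i.e. exactly the enumeration built below.
    local_states = optim.state_dict()["state"]

    master2idx, cnt = {}, 0
    for group in optim.param_groups:
        for param in group["params"]:
            master2idx[param] = cnt
            cnt += 1

    for group in optim.param_groups:
        for param in group["params"]:
            states = local_states[master2idx[param]]
            # Per-parameter entries match the live optimizer state for that param.
            assert set(states) == set(optim.state[param])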