@@ -784,27 +784,27 @@ def get_metadata_file_name(path):
784784 metadata = paddle .load (metadata_file )
785785 state_dict_metadata .update (metadata .state_dict_metadata )
786786
787- init_optimizer (self .optimizer , model_sharded_state_dict , state_dict_metadata )
787+ if not self .args .sharded_model_from_ema :
788+ init_optimizer (self .optimizer , model_sharded_state_dict , state_dict_metadata )
788789
789- optimizer_sharded_state_dict = self .optimizer .sharded_state_dict (model_sharded_state_dict )
790+ optimizer_sharded_state_dict = self .optimizer .sharded_state_dict (model_sharded_state_dict )
790791
791- opt_states = {}
792- master_weights = {}
793- for k , v in optimizer_sharded_state_dict .items ():
794- if k .endswith (".w_0" ):
795- master_weights [k ] = v
796- else :
797- opt_states [k ] = v
792+ opt_states = {}
793+ master_weights = {}
794+ for k , v in optimizer_sharded_state_dict .items ():
795+ if k .endswith (".w_0" ):
796+ master_weights [k ] = v
797+ else :
798+ opt_states [k ] = v
798799
799- dist .load_state_dict (
800- opt_states ,
801- opt_states_path ,
802- aoa_config = self .args .aoa_config ,
803- offload = self .args .load_via_cpu ,
804- comm_method = self .args .flex_ckpt_comm_method ,
805- )
800+ dist .load_state_dict (
801+ opt_states ,
802+ opt_states_path ,
803+ aoa_config = self .args .aoa_config ,
804+ offload = self .args .load_via_cpu ,
805+ comm_method = self .args .flex_ckpt_comm_method ,
806+ )
806807
807- if not self .args .sharded_model_from_ema :
808808 dist .load_state_dict (
809809 master_weights ,
810810 master_weights_path ,
@@ -819,12 +819,8 @@ def get_metadata_file_name(path):
819819 ema_states_path = os .path .join (resume_from_checkpoint , EMA_STATE_DIC , f"{ dist .get_rank ()} _0.distcp" )
820820 ema_state_dict = paddle .load (ema_states_path )
821821 ema_master_weights = ema_state_dict .pop ("master_weights" , None )
822- opt_master_weights = self .optimizer .state_dict ()["master_weights" ]
823- for k , v in opt_master_weights .items ():
824- assert (
825- k in ema_master_weights
826- ), f"{ k } not in ema_master_weights, emas_master_weight keys { ema_master_weights .keys ()} "
827- paddle .assign (ema_master_weights [k ], opt_master_weights [k ])
822+ opt_state_dict = {"master_weights" : ema_master_weights }
823+ self .optimizer .set_state_dict (opt_state_dict )
828824
829825 self .model .set_state_dict (ema_state_dict )
830826 else :
0 commit comments