@@ -220,36 +220,34 @@ def ema_state_dict(self):
220
220
ema_state_dict [k ] = tensor
221
221
ema_state_dict_master_weights = {}
222
222
for k , meta in self .optimizer_fusion_storage_helper .master_weights_meta .items ():
223
- t = self .ema_buffer . _slice (
224
- meta [ "start" ] - self . master_min_offset , meta ["end" ] - self .master_min_offset
225
- ).clone ()
223
+ s = meta [ "start" ] - self .master_min_offset
224
+ e = meta ["end" ] - self .master_min_offset
225
+ t = self . ema_buffer . _slice ( s , e ).clone ()
226
226
t .get_tensor ()._set_dims (meta ["shape" ])
227
227
t .name = meta ["name" ]
228
228
ema_state_dict_master_weights [k ] = t
229
229
ema_state_dict ["master_weights" ] = ema_state_dict_master_weights
230
230
return ema_state_dict
231
231
232
def load_ema_state_dict(self, state_dict):
    """Copy EMA values from an already-loaded state dict into the EMA buffers.

    Loading from disk is the caller's job; this method only consumes the
    in-memory mapping it is given.

    Args:
        state_dict: Mapping from model-weight key to tensor, which must
            also contain a ``"master_weights"`` entry mapping
            optimizer-weight key to tensor. Individual weight keys may be
            missing (the caller may pass a filtered dict); missing keys
            are skipped and the corresponding buffer region is left as is.

    Raises:
        KeyError: If ``state_dict`` has no ``"master_weights"`` entry.
    """
    # Model weights: each entry is written into its [start:end) slot of
    # the buffer selected by the entry's ``buffer_index``.
    for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
        logger.info(f"[ZCC EMA] load model weight key={k}")
        start = tensor_meta["start"]
        end = tensor_meta["end"]
        if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
            continue  # non fp32 has no `self.ema_buffer_model_params`
        if k in state_dict:
            cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]]
            # Flatten before the slice assignment into [start:end).
            tensor = state_dict[k].flatten()
            cpu_buffer[start:end] = tensor

    # Master weights: meta offsets are rebased by ``master_min_offset``
    # before indexing into ``ema_buffer``.
    ema_master = state_dict["master_weights"]
    for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
        logger.info(f"[ZCC EMA] load optimizer weight key={k}")
        s = meta["start"] - self.master_min_offset
        e = meta["end"] - self.master_min_offset
        if k in ema_master:  # state-dict is filtered
            self.ema_buffer[s:e] = ema_master[k].flatten()
253
251
254
252
255
253
class ParamFusionStorageHelper :
@@ -408,11 +406,6 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
408
406
logger .info ("[ZCC manager] Synced checkpoints." )
409
407
410
408
def on_step_end (self , args , state , control , model , lr_scheduler , optimizer , ** kwargs ):
411
- if not isinstance (model , PipelineLayer ):
412
- self .manager .zcc_pipeline_hook (0 )
413
- # logger.info(
414
- # f"check coef: {args.zcc_save_ema_coef} {control.should_save}, {state.global_step}, {self.zcc_ema_interval}"
415
- # )
416
409
if not control .should_save :
417
410
if args .zcc_save_ema_coef is not None and state .global_step % self .zcc_ema_interval == 0 :
418
411
self .maybe_update_zcc_worker (args , model , optimizer , state .global_step )
@@ -425,6 +418,8 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
425
418
non_cached_objects = (lr_scheduler .state_dict (), state , self .get_rng_states (args ))
426
419
self .manager .get_idle_worker_for_saving ((save_infos , non_cached_objects ))
427
420
self .runtime_timer .stop ()
421
+ if not isinstance (model , PipelineLayer ):
422
+ self .manager .zcc_pipeline_hook (0 )
428
423
429
424
def get_rng_states (self , args ):
430
425
if not args .save_rng_states :
@@ -959,7 +954,15 @@ def run(self):
959
954
self .optimizer_fusion_storage_helper , self .param_fusion_storage_helper , self .ema_coef
960
955
)
961
956
if ema_ckpt_path is not None : # update ema if needed
962
- self .zcc_ema_processor .load_ema_state_dict (ema_ckpt_path )
957
+ logger .info (f"[ZCC EMA] load state dict from { ema_ckpt_path } " )
958
+ with device_guard ("cpu" ):
959
+ state_dict = paddle .load (ema_ckpt_path )
960
+ if self .use_expert_parallel and self .dp_rank > 0 :
961
+ state_dict = self ._filter_moe_no_sync_optimizer_params (
962
+ self .model_meta_content , state_dict
963
+ )
964
+ self .zcc_ema_processor .load_ema_state_dict (state_dict )
965
+ logger .info ("[ZCC EMA] done loading" )
963
966
ema_ckpt_path = None
964
967
elif task_type == ZCCTaskType .PREPARE :
965
968
start_time = time .time ()
0 commit comments