 import os
 import random
 import time
+from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
 from dataclasses import replace
 from enum import Enum
@@ -1142,28 +1143,21 @@ def manage_offload_chunk(self):
         )
 
 
-class EMABuffer:
-    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
-        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+class EMABuffer(ABC):
+    def __init__(self, resume_from_checkpoint, args, offload=True):
         self.master_weights = {}
         self.model_params = {}
         self.args = args
-        self.sharding_io = sharding_io
         self.offload = offload
         if resume_from_checkpoint is not None:
             self._load(resume_from_checkpoint)
 
-    def _ema_path(self, base_path):
-        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
-        path = path.replace("optimizer", "ema")
-        return os.path.join(base_path, path)
-
     def _load(self, resume_from_checkpoint):
         ema_path = self._ema_path(resume_from_checkpoint)
         if not os.path.exists(ema_path):
             return
 
-        success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+        success, err_msg = self._check_consistent_dist_strategy(resume_from_checkpoint)
         if not success:
             logger.info(f"Cannot load EMA because: {err_msg}")
             return
@@ -1190,14 +1184,11 @@ def ema_accumulate(self, global_step, loss, ema_loss_threshold):
         if ema_loss_threshold is None or loss < ema_loss_threshold:
             logger.info(f"EMA accumulating for step {global_step}...")
             self._ema_impl(
-                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
+                state_dict=self._get_master_weight(),
                 ema_state_dict=self.master_weights,
             )
             self._ema_impl(
-                state_dict=self.sharding_io.manipulate_state_dict_and_config(
-                    unwrap_model(self.sharding_io.model),
-                    merge_tensor_parallel=False,
-                )[0],
+                state_dict=self._get_model_state(),
                 ema_state_dict=self.model_params,
             )
             logger.info(f"EMA accumulate done for step {global_step}")
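
For reference, a minimal sketch of the exponential-moving-average update that _ema_impl appears to apply per entry; the function name, decay coefficient, and the clone fallback are assumptions for illustration only (the real implementation, shown in the next hunk, additionally offloads values to pinned memory when offload is enabled):

# Sketch only (not part of this commit): a plausible shape for the per-tensor EMA update.
def ema_update_sketch(state_dict, ema_state_dict, decay=0.999):
    for k, v in state_dict.items():
        if k not in ema_state_dict:
            # First accumulation: seed the EMA with the current value.
            ema_state_dict[k] = v.clone() if hasattr(v, "clone") else v
        else:
            # ema = decay * ema + (1 - decay) * current
            ema_state_dict[k] = decay * ema_state_dict[k] + (1.0 - decay) * v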
@@ -1218,10 +1209,95 @@ def _ema_impl(self, state_dict, ema_state_dict):
                 v = v_pin
             ema_state_dict[k] = v
 
+    @abstractmethod
+    def _get_master_weight(self):
+        pass
 
-class NonZCCEMACallback(TrainerCallback):
+    @abstractmethod
+    def _get_model_state(self):
+        pass
+
+    @abstractmethod
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        pass
+
+
+class EMABufferShardingIOBased(EMABuffer):
     def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
-        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)
+        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+        self.sharding_io = sharding_io
+        super().__init__(resume_from_checkpoint, args, offload)
+
+    def _ema_path(self, base_path):
+        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        path = path.replace("optimizer", "ema")
+        return os.path.join(base_path, path)
+
+    def _get_model_state(self):
+        return self.sharding_io.manipulate_state_dict_and_config(
+            unwrap_model(self.sharding_io.model),
+            merge_tensor_parallel=False,
+        )[0]
+
+    def _get_master_weight(self):
+        return self.sharding_io.optimizer.state_dict()["master_weights"]
+
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        return self.sharding_io.check_same_strategy(resume_from_checkpoint)
+
+
+class EMABufferFcBased(EMABuffer):
+    def __init__(self, resume_from_checkpoint, args, offload=True, hcg=None, model=None, optimizer=None):
+        self.hcg = hcg
+        self.model = model
+        self.optimizer = optimizer
+        self.dist_info_collector_and_validator = DistInfoCollectorValidator(args, hcg)
+
+        super().__init__(resume_from_checkpoint, args, offload)
+
+    def _get_model_meta(self):
+        return self.dist_info_collector_and_validator.gather_distributed_model_meta(self.model, self.optimizer)
+
+    def _ema_path(self, base_path):
+        return os.path.join(base_path, "ema_state", f"{dist.get_rank()}_0.distcp")
+
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        return self.dist_info_collector_and_validator.check_same_strategy(os.path.dirname(resume_from_checkpoint))
+
+    def _get_model_state(self):
+        assert self.model is not None, "expected model is not None"
+        return self.model.state_dict()
+
+    def _get_master_weight(self):
+        assert self.optimizer is not None, "expected optimizer is not None"
+        return self.optimizer.state_dict()["master_weights"]
+
+    def save(self, global_step):
+        model_meta_content = self._get_model_meta()
+        model_meta_path = os.path.join(self.args.output_dir, MODEL_META_NAME)
+        with open(model_meta_path, "w") as f:
+            json.dump(model_meta_content, f)
+
+        super().save(global_step)
+
+
+class NonZCCEMACallback(TrainerCallback):
+    def __init__(self, ema_buffer: EMABuffer):
+        self.buffer = ema_buffer
+
+    @staticmethod
+    def create_nonzcc_callback(
+        args, resume_from_checkpoint, sharding_io=None, model=None, optimizer=None, hcg=None, offload=True
+    ):
+        if args.save_checkpoint_format == "flex_checkpoint":
+            ema_buffer = EMABufferFcBased(
+                resume_from_checkpoint, args, offload=offload, hcg=hcg, model=model, optimizer=optimizer
+            )
+        else:
+            assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+            ema_buffer = EMABufferShardingIOBased(resume_from_checkpoint, args, sharding_io, offload=offload)
+
+        return NonZCCEMACallback(ema_buffer)
 
     def on_step_end(self, args, state, control, **kwargs):
         if state.global_step % args.zcc_ema_interval == 0:
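
Usage sketch (hypothetical, not part of this commit): how a trainer might obtain the callback through the new factory instead of constructing an EMA buffer directly; the surrounding variable names and the add_callback registration are assumptions for illustration.

# Hypothetical wiring of the refactored callback; the factory selects the
# flex-checkpoint buffer or the sharding-IO buffer from args.save_checkpoint_format.
ema_callback = NonZCCEMACallback.create_nonzcc_callback(
    args=training_args,                      # assumed TrainingArguments-like object
    resume_from_checkpoint=resume_from_checkpoint,
    sharding_io=sharding_io,                 # required on the non-flex_checkpoint path
    model=model,
    optimizer=optimizer,
    hcg=hcg,                                 # hybrid communication group, used on the flex_checkpoint path
    offload=True,
)
trainer.add_callback(ema_callback)           # assumed registration API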