
Commit 32b9a2e

decouple sharding io and non zcc (#11201)
1 parent 51f73a5 commit 32b9a2e

File tree: 2 files changed (+103, -18 lines)

paddlenlp/trainer/trainer.py

Lines changed: 10 additions & 1 deletion
@@ -1175,7 +1175,16 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         logger.info("Zero cost checkpoint manager created successfully.")
 
     def add_non_zcc_ema_callback(self, resume_from_checkpoint):
-        self.add_callback(NonZCCEMACallback(resume_from_checkpoint, self.args, self.sharding_io))
+        non_zcc_ema_callback = NonZCCEMACallback.create_nonzcc_callback(
+            args=self.args,
+            resume_from_checkpoint=resume_from_checkpoint,
+            sharding_io=self.sharding_io,
+            model=self.model,
+            optimizer=self.optimizer,
+            hcg=self.hcg,
+        )
+
+        self.add_callback(non_zcc_ema_callback)
 
     def train(
         self,
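Note: after this change the trainer no longer instantiates NonZCCEMACallback directly; it calls the NonZCCEMACallback.create_nonzcc_callback factory (added in the second file below), which picks an EMA buffer backend from args.save_checkpoint_format. A minimal sketch of that dispatch, using hypothetical stand-in classes rather than the real PaddleNLP types:

from types import SimpleNamespace

class ShardingIOBackedBuffer:  # stand-in for EMABufferShardingIOBased
    def __init__(self, sharding_io):
        self.sharding_io = sharding_io

class FlexCheckpointBackedBuffer:  # stand-in for EMABufferFcBased
    def __init__(self, model, optimizer, hcg):
        self.model, self.optimizer, self.hcg = model, optimizer, hcg

def create_buffer(args, sharding_io=None, model=None, optimizer=None, hcg=None):
    # Same branching as create_nonzcc_callback: flex_checkpoint gets the
    # flex-checkpoint backend; every other format still requires sharding_io.
    if args.save_checkpoint_format == "flex_checkpoint":
        return FlexCheckpointBackedBuffer(model, optimizer, hcg)
    assert sharding_io is not None, "EMA needs sharding_io when not using flex_checkpoint"
    return ShardingIOBackedBuffer(sharding_io)

buf = create_buffer(SimpleNamespace(save_checkpoint_format="flex_checkpoint"),
                    model="m", optimizer="o", hcg="h")
print(type(buf).__name__)  # FlexCheckpointBackedBuffer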

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 93 additions & 17 deletions
@@ -21,6 +21,7 @@
 import os
 import random
 import time
+from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
 from dataclasses import replace
 from enum import Enum
@@ -1142,28 +1143,21 @@ def manage_offload_chunk(self):
         )
 
 
-class EMABuffer:
-    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
-        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+class EMABuffer(ABC):
+    def __init__(self, resume_from_checkpoint, args, offload=True):
         self.master_weights = {}
         self.model_params = {}
         self.args = args
-        self.sharding_io = sharding_io
         self.offload = offload
         if resume_from_checkpoint is not None:
             self._load(resume_from_checkpoint)
 
-    def _ema_path(self, base_path):
-        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
-        path = path.replace("optimizer", "ema")
-        return os.path.join(base_path, path)
-
     def _load(self, resume_from_checkpoint):
         ema_path = self._ema_path(resume_from_checkpoint)
         if not os.path.exists(ema_path):
             return
 
-        success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+        success, err_msg = self._check_consistent_dist_strategy(resume_from_checkpoint)
         if not success:
             logger.info(f"Cannot load EMA because: {err_msg}")
             return
@@ -1190,14 +1184,11 @@ def ema_accumulate(self, global_step, loss, ema_loss_threshold):
         if ema_loss_threshold is None or loss < ema_loss_threshold:
             logger.info(f"EMA accumulating for step {global_step} ...")
             self._ema_impl(
-                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
+                state_dict=self._get_master_weight(),
                 ema_state_dict=self.master_weights,
             )
             self._ema_impl(
-                state_dict=self.sharding_io.manipulate_state_dict_and_config(
-                    unwrap_model(self.sharding_io.model),
-                    merge_tensor_parallel=False,
-                )[0],
+                state_dict=self._get_model_state(),
                 ema_state_dict=self.model_params,
             )
         logger.info(f"EMA accumulate done for step {global_step}")
@@ -1218,10 +1209,95 @@ def _ema_impl(self, state_dict, ema_state_dict):
                 v = v_pin
             ema_state_dict[k] = v
 
+    @abstractmethod
+    def _get_master_weight(self):
+        pass
 
-class NonZCCEMACallback(TrainerCallback):
+    @abstractmethod
+    def _get_model_state(self):
+        pass
+
+    @abstractmethod
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        pass
+
+
+class EMABufferShardingIOBased(EMABuffer):
     def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
-        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)
+        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+        self.sharding_io = sharding_io
+        super().__init__(resume_from_checkpoint, args, offload)
+
+    def _ema_path(self, base_path):
+        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        path = path.replace("optimizer", "ema")
+        return os.path.join(base_path, path)
+
+    def _get_model_state(self):
+        return self.sharding_io.manipulate_state_dict_and_config(
+            unwrap_model(self.sharding_io.model),
+            merge_tensor_parallel=False,
+        )[0]
+
+    def _get_master_weight(self):
+        return self.sharding_io.optimizer.state_dict()["master_weights"]
+
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        return self.sharding_io.check_same_strategy(resume_from_checkpoint)
+
+
+class EMABufferFcBased(EMABuffer):
+    def __init__(self, resume_from_checkpoint, args, offload=True, hcg=None, model=None, optimizer=None):
+        self.hcg = hcg
+        self.model = model
+        self.optimizer = optimizer
+        self.dist_info_collector_and_validator = DistInfoCollectorValidator(args, hcg)
+
+        super().__init__(resume_from_checkpoint, args, offload)
+
+    def _get_model_meta(self):
+        return self.dist_info_collector_and_validator.gather_distributed_model_meta(self.model, self.optimizer)
+
+    def _ema_path(self, base_path):
+        return os.path.join(base_path, "ema_state", f"{dist.get_rank()}_0.distcp")
+
+    def _check_consistent_dist_strategy(self, resume_from_checkpoint):
+        return self.dist_info_collector_and_validator.check_same_strategy(os.path.dirname(resume_from_checkpoint))
+
+    def _get_model_state(self):
+        assert self.model is not None, "expected model is not None"
+        return self.model.state_dict()
+
+    def _get_master_weight(self):
+        assert self.optimizer is not None, "expected optimizer is not None"
+        return self.optimizer.state_dict()["master_weights"]
+
+    def save(self, global_step):
+        model_meta_content = self._get_model_meta()
+        model_meta_path = os.path.join(self.args.output_dir, MODEL_META_NAME)
+        with open(model_meta_path, "w") as f:
+            json.dump(model_meta_content, f)
+
+        super().save(global_step)
+
+
+class NonZCCEMACallback(TrainerCallback):
+    def __init__(self, ema_buffer: EMABuffer):
+        self.buffer = ema_buffer
+
+    @staticmethod
+    def create_nonzcc_callback(
+        args, resume_from_checkpoint, sharding_io=None, model=None, optimizer=None, hcg=None, offload=True
+    ):
+        if args.save_checkpoint_format == "flex_checkpoint":
+            ema_buffer = EMABufferFcBased(
+                resume_from_checkpoint, args, offload=offload, hcg=hcg, model=model, optimizer=optimizer
+            )
+        else:
+            assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+            ema_buffer = EMABufferShardingIOBased(resume_from_checkpoint, args, sharding_io, offload=offload)
+
+        return NonZCCEMACallback(ema_buffer)
 
     def on_step_end(self, args, state, control, **kwargs):
         if state.global_step % args.zcc_ema_interval == 0:
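The refactor itself is a template-method split: EMABuffer keeps the shared EMA accumulation and load logic and declares _get_master_weight, _get_model_state, and _check_consistent_dist_strategy as abstract hooks, while EMABufferShardingIOBased and EMABufferFcBased only supply backend-specific state access. A minimal, self-contained sketch of that shape, assuming toy stand-ins rather than the real PaddleNLP classes:

from abc import ABC, abstractmethod

class EMABufferBase(ABC):
    # Shared EMA bookkeeping; subclasses decide where the state comes from.
    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema = {}

    def accumulate(self):
        # Template method: the update rule is shared, the state source is not.
        for k, v in self._get_model_state().items():
            self.ema[k] = self.decay * self.ema.get(k, v) + (1 - self.decay) * v

    @abstractmethod
    def _get_model_state(self):
        ...

class DictBackedBuffer(EMABufferBase):
    # Toy backend that reads parameters from a plain dict.
    def __init__(self, state, decay=0.999):
        super().__init__(decay)
        self.state = state

    def _get_model_state(self):
        return self.state

buf = DictBackedBuffer({"w": 1.0})
buf.accumulate()
print(buf.ema)  # {'w': 1.0}

Because the base class never touches sharding_io directly, adding the flex-checkpoint backend only needs a new subclass plus the factory branch above, which is the decoupling named in the commit title.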
