
Commit b0276d7

【FlexCheckpoint】adapter flex checkpoint (#11188)
* adapter flex checkpoint
* opt bf16 save
* fix multi moe_sharding_group load safetensors
* fix
1 parent fa007c8 commit b0276d7

5 files changed (+507 additions, -121 deletions)


paddlenlp/trainer/trainer.py

Lines changed: 253 additions & 37 deletions
@@ -121,9 +121,13 @@
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
     DISLORA_WEIGHTS_NAME,
+    EMA_STATE_DIC,
     LOKR_WEIGHTS_NAME,
     LORA_WEIGHTS_NAME,
+    MASTER_WEIGHT_DIC,
     MODEL_META_NAME,
+    MODEL_STATE_DIC,
+    OPTIMIZER_STATE_DIC,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
     PADDLE_OPTIMIZER_NAME,
     PADDLE_PEFT_WEIGHTS_INDEX_NAME,
@@ -185,6 +189,8 @@
 from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
+from .utils.reshard import SHARDING_STRATEGY_V1, split_opt_state
+from .utils.sharding_io import GroupGetter, to_device

 try:
     from .utils.zero_cost_checkpoint import (
@@ -673,6 +679,248 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
         elif resume_from_checkpoint is not None:
             logger.info(f"not loading ckpt :{self.args.dataset_rank}")

+    def _load_flex_checkpoint(self, resume_from_checkpoint):
+        def get_metadata_file_name(path):
+            files = os.listdir(path)
+            metadata_files = [f for f in files if f.endswith(".metadata")]
+            assert len(metadata_files) > 0, f"Found no metadata files in {path}"
+            assert len(metadata_files) == 1, f"Found multiple metadata files in {path}"
+            return metadata_files[0]
+
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        hf_aoa_config = self.model._gen_aoa_config(self.model.config)
+        master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
+        opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
+        model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)
+
+        if self.args.load_from_hf:
+            hcg = dist.fleet.get_hybrid_communicate_group()
+            assert (
+                self.args.ignore_load_lr_and_optim
+            ), "Loading from HuggingFace format is only allowed when learning rate and optimizer state are ignored."
+            try:
+                moe_sharding_group = hcg.get_moe_sharding_parallel_group()
+            except Exception:
+                moe_sharding_group = None
+
+            if moe_sharding_group is None or moe_sharding_group.nranks <= 1:
+                # when moe_sharding_group is None, we use the default process_group
+                logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
+                dist.load_state_dict(
+                    model_sharded_state_dict,
+                    resume_from_checkpoint,
+                    aoa_config=hf_aoa_config,
+                    offload=self.args.load_via_cpu,
+                    safetensors=True,
+                    process_group=None,
+                    comm_method=self.args.comm_method,
+                )
+            else:
+                try:
+                    pp_group = hcg.get_pipe_parallel_group()
+                    if pp_group is None or pp_group.nranks < 1:
+                        raise NotImplementedError("Only support when pp_group is not None.")
+                except Exception:
+                    raise RuntimeError("Only support when pp_group is not None.")
+
+                try:
+                    moe_group = hcg.get_expert_parallel_group()
+                    if moe_group is None or moe_group.nranks < 1:
+                        raise NotImplementedError("Only support when moe_group is not None.")
+                except Exception:
+                    raise RuntimeError("Only support when moe_group is not None.")
+                moe_sharding_rank = moe_sharding_group.rank
+                cur_rank = dist.get_rank()
+                if moe_sharding_rank == 0:
+                    moe_group_ranks = []
+                    dist.all_gather_object(moe_group_ranks, cur_rank, group=moe_group)
+                    pp_group_ranks = []
+                    dist.all_gather_object(pp_group_ranks, moe_group_ranks, group=pp_group)
+                    process_group_ranks = [rank for ranks in pp_group_ranks for rank in ranks]
+                else:
+                    process_group_ranks = [0] * (pp_group.nranks * moe_group.nranks)
+                src_rank = hcg.get_moe_sharding_parallel_group_src_rank()
+                dist.broadcast_object_list(process_group_ranks, src=src_rank, group=moe_sharding_group)
+                assert any(process_group_ranks), "process_group_ranks should not be all 0"
+                logger.info(f"Creating a temporary process group with ranks: {process_group_ranks}")
+                process_group = dist.new_group(process_group_ranks)
+
+                if moe_sharding_rank == 0:
+                    logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
+                    # Only the first moe_sharding process is allowed to load the model weights.
+                    dist.load_state_dict(
+                        model_sharded_state_dict,
+                        resume_from_checkpoint,
+                        aoa_config=hf_aoa_config,
+                        offload=self.args.load_via_cpu,
+                        safetensors=True,
+                        process_group=process_group,
+                        comm_method=self.args.comm_method,
+                    )
+
+                dist.barrier()
+                logger.info("Destroying the temporary process group.")
+                dist.destroy_process_group(process_group)
+                # The first moe_sharding group loads the model weights and then broadcasts them to all other moe_sharding groups.
+                logger.info(
+                    "First shard (moe_sharding_group) has loaded safetensors weights, starting broadcast on moe_sharding_groups."
+                )
+                for param_name, param in self.model.state_dict().items():
+                    dist.broadcast(param, src=src_rank, group=moe_sharding_group)
+            logger.info("Safetensors format weights have been loaded successfully.")
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            state_dict_metadata = {}
+            metadata_paths = [
+                os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
+                os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
+                os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
+            ]
+
+            for metadata_file in metadata_paths:
+                if not os.path.exists(metadata_file):
+                    raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
+                metadata = paddle.load(metadata_file)
+                state_dict_metadata.update(metadata.state_dict_metadata)
+
+            init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+
+            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+
+            opt_states = {}
+            master_weights = {}
+            for k, v in optimizer_sharded_state_dict.items():
+                if k.endswith(".w_0"):
+                    master_weights[k] = v
+                else:
+                    opt_states[k] = v
+
+            dist.load_state_dict(
+                opt_states,
+                opt_states_path,
+                aoa_config=self.args.aoa_config,
+                offload=self.args.load_via_cpu,
+                comm_method=self.args.comm_method,
+            )
+
+            if not self.args.sharded_model_from_ema:
+                dist.load_state_dict(
+                    master_weights,
+                    master_weights_path,
+                    aoa_config=self.args.aoa_config,
+                    offload=self.args.load_via_cpu,
+                )
+
+            self._load_scheduler(resume_from_checkpoint)
+
+        should_load_stage1 = self.args.should_load_sharding_stage1_model
+        if should_load_stage1 and self.args.sharded_model_from_ema:
+            ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
+            ema_state_dict = paddle.load(ema_states_path)
+            ema_master_weights = ema_state_dict.pop("master_weights", None)
+            opt_master_weights = self.optimizer.state_dict()["master_weights"]
+            for k, v in opt_master_weights.items():
+                assert (
+                    k in ema_master_weights
+                ), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
+                paddle.assign(ema_master_weights[k], opt_master_weights[k])
+
+            ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
+            self.model.set_state_dict(ema_state_dict)
+        else:
+            dist.load_state_dict(
+                model_sharded_state_dict,
+                model_states_path,
+                aoa_config=self.args.aoa_config,
+                offload=self.args.load_via_cpu,
+            )
+
+        if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
+            opt_state_dict = self.optimizer.state_dict()
+
+            def recover_params_from_master_weight(opt_state_dict, group):
+                master_weights = opt_state_dict["master_weights"]
+                tmp = OrderedDict()
+                (master_weights, tmp) = (tmp, master_weights)
+                # cast to before
+                for (k, v) in tmp.items():
+                    name = v.name
+                    master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu()
+                    master_weights[k].name = name
+
+                structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()}
+                node_model_state = reshard_util.NodeModelState(group=group)
+                node_model_state_tmp = reshard_util.NodeModelState(group=group)
+                node_model_state_tmp.add_master_weights(master_weights)
+                node_model_state_tmp.pack_keys(structure_name_map)
+                node_model_state.merge_from(node_model_state_tmp, max(group.rank, 0))
+                del node_model_state_tmp
+                sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)
+                logger.debug(f"sharding_strategy: {sharding_strategy}")
+                restore_func = (
+                    reshard_util.sharding_v1.restore
+                    if sharding_strategy == SHARDING_STRATEGY_V1
+                    else reshard_util.sharding_v2.restore
+                )
+                node_model_state = restore_func(node_model_state, self.model, self.optimizer)
+                node_model_state.unpack_keys()
+                master_weights = node_model_state.master_weights
+
+                master_weights = reshard_util.all_gather_state_dict(master_weights, lambda x: True, group)
+
+                model_state_dict = self.model.state_dict()
+                for key, param in model_state_dict.items():
+                    if param.name in master_weights:
+                        logger.debug(
+                            f"key {key}, convert master weights {param.name} shape {master_weights[param.name].shape} to param {param.name} shape{param.shape}"
+                        )
+                        assert (
+                            param.shape == master_weights[param.name].shape
+                        ), f"got {param.shape} vs {master_weights[param.name].shape}"
+                        master_weight = paddle.reshape(master_weights[param.name], param.shape)
+                        paddle.assign(paddle.cast(to_device(master_weight), paddle.bfloat16), model_state_dict[key])
+
+            group_getter = GroupGetter(self.model)
+            opt_state_dict = split_opt_state(opt_state_dict, group_getter)
+            for gid in group_getter.get_group_ids():
+                sub_opt_state_dict = opt_state_dict[gid]
+                group = group_getter.get_group_by_id(gid)
+                if self.args.bf16:
+                    recover_params_from_master_weight(sub_opt_state_dict, group)
+
+    def _save_flex_model_state(self, output_dir):
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
+        os.makedirs(model_state_dict_path, exist_ok=True)
+        dist.save_state_dict(
+            model_sharded_state_dict,
+            model_state_dict_path,
+        )
+
+    def _save_flex_optimizer_state(self, output_dir):
+        optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
+        optimizer_states = {}
+        master_weights = {}
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+        for k, v in optimizer_sharded_state_dict.items():
+            if k.endswith(".w_0"):
+                master_weights[k] = v
+            else:
+                optimizer_states[k] = v
+
+        dist.save_state_dict(
+            optimizer_states,
+            optimizer_state_dict_path,
+        )
+
+        master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
+        dist.save_state_dict(
+            master_weights,
+            master_weights_path,
+        )
+
     def _load_from_checkpoint(self, resume_from_checkpoint=None):
         """load state_dict from_checkpoint, Only load model state dict.
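Both _load_flex_checkpoint and _save_flex_optimizer_state above rely on the same naming rule to separate fp32 master weights from the rest of the optimizer state: keys ending in ".w_0" are treated as master weights, everything else (moments, beta accumulators, and so on) stays in the optimizer-state dictionary. A minimal, framework-free sketch of that split; the dictionary values here are illustrative placeholders, not real sharded tensors:

def split_master_weights(optimizer_sharded_state_dict):
    # Keys ending in ".w_0" hold the fp32 master copy of a parameter;
    # every other key holds optimizer state such as Adam moments.
    master_weights, opt_states = {}, {}
    for key, value in optimizer_sharded_state_dict.items():
        if key.endswith(".w_0"):
            master_weights[key] = value
        else:
            opt_states[key] = value
    return master_weights, opt_states

# Illustrative usage with placeholder strings instead of tensors.
demo = {"linear_0.w_0": "master copy", "linear_0.w_0_moment1_0": "adam moment"}
master, states = split_master_weights(demo)
assert "linear_0.w_0" in master and "linear_0.w_0_moment1_0" in states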
@@ -1048,27 +1296,7 @@ def train(
             if delay_optimizer_creation:
                 self.create_optimizer_and_scheduler(num_training_steps=max_steps)

-            if resume_from_checkpoint is not None:
-                if not self.args.ignore_load_lr_and_optim:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(
-                        sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
-                    )
-                    self._load_scheduler(resume_from_checkpoint)
-                else:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    sharded_state_dict = model_sharded_state_dict
-                    dist.load_state_dict(
-                        sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
-                    )
+            self._load_flex_checkpoint(resume_from_checkpoint)
         else:
             model = self._wrap_model(self.model_wrapped)
             # for the rest of this function `model` is the outside model, whether it was wrapped or not
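The removed resume logic above expected a single ".metadata" file at the checkpoint root; _load_flex_checkpoint instead resolves one metadata file per state directory (model state, optimizer state, master weights) through its get_metadata_file_name helper. A standalone sketch of that per-directory discovery rule, demonstrated against a throwaway directory; it mirrors the helper rather than reproducing the library's API:

import os
import tempfile

def find_single_metadata_file(path):
    # A flex-checkpoint state directory is expected to contain exactly one ".metadata" file.
    metadata_files = [f for f in os.listdir(path) if f.endswith(".metadata")]
    assert len(metadata_files) > 0, f"Found no metadata files in {path}"
    assert len(metadata_files) == 1, f"Found multiple metadata files in {path}"
    return os.path.join(path, metadata_files[0])

# Demonstration only: create an empty metadata file and resolve it.
with tempfile.TemporaryDirectory() as state_dir:
    open(os.path.join(state_dir, "0.metadata"), "w").close()
    print(find_single_metadata_file(state_dir))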
@@ -2867,8 +3095,7 @@ def _save_checkpoint(self, model, metrics=None):
         self.save_model(output_dir)

         if self.args.save_checkpoint_format == "flex_checkpoint":
-            model_sharded_state_dict = self.model.sharded_state_dict()
-            os.makedirs(output_dir, exist_ok=True)
+            self._save_flex_model_state(output_dir)

         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2932,11 +3159,7 @@ def _save_checkpoint(self, model, metrics=None):
                     )
                 else:
                     if self.args.save_checkpoint_format == "flex_checkpoint":
-                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                        dist.save_state_dict(
-                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                            output_dir,
-                        )
+                        self._save_flex_optimizer_state(output_dir)
                     if self.args.should_save:
                         if self.tokenizer is not None and self.args.save_tokenizer:
                             self.tokenizer.save_pretrained(output_dir)
@@ -2992,11 +3215,7 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 elif self.args.save_checkpoint_format == "flex_checkpoint":
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    dist.save_state_dict(
-                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                        output_dir,
-                    )
+                    self._save_flex_optimizer_state(output_dir)
                 if self.args.should_save:
                     if self.tokenizer is not None and self.args.save_tokenizer:
                         self.tokenizer.save_pretrained(output_dir)
@@ -3039,10 +3258,7 @@ def _save_checkpoint(self, model, metrics=None):
                 self._offload_optimizer()
             else:
                 if self.args.save_checkpoint_format == "flex_checkpoint":
-                    dist.save_state_dict(
-                        model_sharded_state_dict,
-                        output_dir,
-                    )
+                    self._save_flex_model_state(output_dir)
                 if self.args.should_save:
                     if self.tokenizer is not None and self.args.save_tokenizer:
                         self.tokenizer.save_pretrained(output_dir)
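Taken together, the save-side changes stop writing one flat dist.save_state_dict dump into output_dir: _save_flex_model_state and _save_flex_optimizer_state place model state, optimizer state, and master weights into separate sub-directories named by the MODEL_STATE_DIC, OPTIMIZER_STATE_DIC, and MASTER_WEIGHT_DIC constants from paddlenlp/utils/env, and _load_flex_checkpoint reads them back from the same locations. A minimal sketch of that layout; the directory-name strings below are placeholders, since the diff does not show the constants' actual values:

import os

# Placeholder values standing in for the paddlenlp.utils.env constants,
# whose real strings are not part of this diff.
MODEL_STATE_DIC = "model_state"
OPTIMIZER_STATE_DIC = "optimizer_state"
MASTER_WEIGHT_DIC = "master_weights"

def flex_checkpoint_layout(checkpoint_dir):
    # Sub-directories written by _save_flex_model_state / _save_flex_optimizer_state
    # and read back by _load_flex_checkpoint.
    return {
        "model_states_path": os.path.join(checkpoint_dir, MODEL_STATE_DIC),
        "opt_states_path": os.path.join(checkpoint_dir, OPTIMIZER_STATE_DIC),
        "master_weights_path": os.path.join(checkpoint_dir, MASTER_WEIGHT_DIC),
    }

print(flex_checkpoint_layout("./checkpoint-1000"))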
