
Commit 870cf8e

[Trainer] Support reshard save/load for sharding stage1 (#6633)
1 parent f5f47ad · commit 870cf8e

File tree

4 files changed: +719, -100 lines changed

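The new behavior is driven by two TrainingArguments flags added in this commit. As a rough usage sketch (not taken from the commit itself: the output_dir and the parallel degree of 4 are placeholder values), a sharding stage1 run that saves and resumes sharded checkpoints might be configured like this:

from paddlenlp.trainer import TrainingArguments

# Sketch only: sharding must include "stage1" and sharding_parallel_degree must be
# greater than 1, otherwise should_save_sharding_stage1_model and
# should_load_sharding_stage1_model (defined below) stay False.
args = TrainingArguments(
    output_dir="./checkpoints",     # placeholder path
    sharding="stage1",
    sharding_parallel_degree=4,     # placeholder degree
    save_sharded_model=True,        # each sharding rank saves only its own slice
    load_sharded_model=True,        # resume from sharded checkpoints, reshard if needed
)

The same fields can also be passed as command-line flags by whichever script parses TrainingArguments.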

paddlenlp/trainer/trainer.py

Lines changed: 126 additions & 84 deletions
@@ -115,6 +115,7 @@
     nested_numpify,
     nested_truncate,
 )
+from .utils.sharding_io import ShardingIO

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
 DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -285,6 +286,9 @@ def __init__(
         self.control = TrainerControl()
         self._signature_columns = None
         self.optimizer_grouped_parameters = None
+        self.sharding_io = None
+        if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
+            self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)

         if self.sharding is not None and self.optimizer is not None:
             raise RuntimeError(
@@ -428,33 +432,52 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

-        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
-            if isinstance(self.model, LoRAModel):
-                weight_name = LORA_WEIGHTS_NAME
-            elif isinstance(self.model, PrefixModelForCausalLM):
-                weight_name = PREFIX_WEIGHTS_NAME
-            else:
-                weight_name = PADDLE_WEIGHTS_NAME
+        if isinstance(self.model, LoRAModel):
+            weight_name = LORA_WEIGHTS_NAME
+        elif isinstance(self.model, PrefixModelForCausalLM):
+            weight_name = PREFIX_WEIGHTS_NAME
+        else:
+            weight_name = PADDLE_WEIGHTS_NAME

-            if not os.path.isfile(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix))
-            ):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+        if self.args.should_load_sharding_stage1_model:
+            state_dict = self.sharding_io.load_state_dict_from_checkpoint_with_reshard(
+                resume_from_checkpoint,
+                base_weight_name=weight_name,
+            )
+        else:
+            if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
+                file_path = os.path.join(
+                    resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
+                )
+                if not os.path.isfile(file_path):
+                    raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

-            logger.info(f"Loading model from {resume_from_checkpoint} .")
+                logger.info(f"Loading model from {resume_from_checkpoint} .")

-            # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = paddle.load(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)),
-                return_numpy=True,
-            )
-            # If the model is on the GPU, it still works!
-            self._set_state_dict_in_model(state_dict)
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = paddle.load(file_path, return_numpy=True)

-            # release memory
-            del state_dict
-        elif resume_from_checkpoint is not None:
-            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+                # If the model is on the GPU, it still works!
+                self._set_state_dict_in_model(state_dict)
+                # release memory
+                del state_dict
+            elif resume_from_checkpoint is not None:
+                logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+
+    def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
+        # In the sharded mode, load_state_dict_from_checkpoint must be invoked after _wrap_model.
+        # In this mode each sharding rank loads its own sharded params, so no broadcast logic is needed.
+        model = self._wrap_model(self.model_wrapped)
+        if self.sharding_io is not None:
+            # self.optimizer must already be wrapped; that is done in _wrap_model
+            self.sharding_io.set_optimizer(self.optimizer)
+        if model is not self.model:
+            self.model_wrapped = model
+        # load_state_dict_from_checkpoint must be invoked after _load_optimizer_and_scheduler
+        # because load_state_dict_from_checkpoint relies on the optimizer in the sharded mode.
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        self.load_state_dict_from_checkpoint(resume_from_checkpoint)
+        return model

     def train(
         self,
@@ -475,43 +498,12 @@ def train(
         """
         args = self.args
         self.is_in_train = True
-        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

         # memory metrics - must set up as early as possible
         self._memory_tracker.start()

-        # Load potential model checkpoint
-        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
-            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
-            if resume_from_checkpoint is None:
-                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
-
-        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
-            if isinstance(self.model, LoRAModel):
-                weight_name = LORA_WEIGHTS_NAME
-            elif isinstance(self.model, PrefixModelForCausalLM):
-                weight_name = PREFIX_WEIGHTS_NAME
-            else:
-                weight_name = PADDLE_WEIGHTS_NAME
-            if not os.path.isfile(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix))
-            ):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
-            logger.info(f"Loading model from {resume_from_checkpoint} .")
-
-            # TODO: Need to load the model state dict on the CPU to avoid an OOM error.
-            state_dict = paddle.load(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)),
-                return_numpy=True,
-            )
-            # If the model is on the GPU, it still works!
-            self._set_state_dict_in_model(state_dict)
-
-            # release memory
-            del state_dict
-        elif resume_from_checkpoint is not None:
-            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+        if not self.args.should_load_sharding_stage1_model:
+            self.load_state_dict_from_checkpoint(resume_from_checkpoint)

         train_dataloader = self.get_train_dataloader()

@@ -563,17 +555,30 @@

         self.state = TrainerState()

-        model = self._wrap_model(self.model_wrapped)
-
-        # for the rest of this function `model` is the outside model, whether it was wrapped or not
-        if model is not self.model:
-            self.model_wrapped = model
-
-        if delay_optimizer_creation:
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-        # Check if saved optimizer or scheduler states exist
-        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        if self.args.should_load_sharding_stage1_model:
+            model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
+        elif self.args.should_save_sharding_stage1_model:
+            # In the non-sharded mode, load_state_dict_from_checkpoint must be invoked before _wrap_model.
+            # In this mode rank0 loads all params and _wrap_model implicitly broadcasts them from rank0 to the other ranks.
+            model = self._wrap_model(self.model_wrapped)
+            if self.sharding_io is not None:
+                assert delay_optimizer_creation is False, "delay_optimizer_creation should be False"
+                # self.optimizer must already be wrapped; that is done in _wrap_model
+                self.sharding_io.set_optimizer(self.optimizer)
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+            if delay_optimizer_creation:
+                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+            self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        else:
+            model = self._wrap_model(self.model_wrapped)
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+            if delay_optimizer_creation:
+                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+            self._load_optimizer_and_scheduler(resume_from_checkpoint)

         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {num_examples:,}")
@@ -1893,12 +1898,26 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             )
         elif not isinstance(self.model, PretrainedModel):
             if isinstance(unwrap_model(self.model), PretrainedModel):
-                unwrap_model(self.model).save_pretrained(
-                    output_dir,
-                    merge_tensor_parallel=merge_tensor_parallel,
-                    variant=self.args.weight_name_suffix,
-                    is_main_process=self.args.should_save,
-                )
+                if self.args.should_save_sharding_stage1_model:
+                    config_to_save = None
+                    state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
+                        unwrap_model(self.model), merge_tensor_parallel=merge_tensor_parallel
+                    )
+                    unwrap_model(self.model).save_pretrained(
+                        output_dir,
+                        state_dict=state_dict,
+                        config_to_save=config_to_save,
+                        merge_tensor_parallel=merge_tensor_parallel,
+                        variant=weight_name_suffix,
+                        is_main_process=self.args.should_save,
+                    )
+                else:
+                    unwrap_model(self.model).save_pretrained(
+                        output_dir,
+                        merge_tensor_parallel=merge_tensor_parallel,
+                        variant=self.args.weight_name_suffix,
+                        is_main_process=self.args.should_save,
+                    )
             else:
                 logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.")
                 if merge_tensor_parallel:
@@ -1910,12 +1929,28 @@
                     os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
                 )
         else:
-            self.model.save_pretrained(
-                output_dir,
-                merge_tensor_parallel=merge_tensor_parallel,
-                variant=self.args.weight_name_suffix,
-                is_main_process=self.args.should_save,
-            )
+            if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
+                config_to_save = None
+                state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
+                    self.model, merge_tensor_parallel=merge_tensor_parallel
+                )
+                self.model.save_pretrained(
+                    output_dir,
+                    state_dict=state_dict,
+                    config_to_save=config_to_save,
+                    merge_tensor_parallel=merge_tensor_parallel,
+                    variant=weight_name_suffix,
+                    is_main_process=self.args.should_save,
+                )
+            else:
+                self.model.save_pretrained(
+                    output_dir,
+                    merge_tensor_parallel=merge_tensor_parallel,
+                    variant=self.args.weight_name_suffix,
+                    is_main_process=self.args.should_save,
+                )
+        if self.args.should_save_sharding_stage1_model:
+            self.sharding_io.save_distributed_model_meta(output_dir)

         if self.args.should_save:
             if self.tokenizer is not None:
@@ -1929,13 +1964,20 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if checkpoint is None:
             return

-        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        opt_state_dict = None
+        if self.args.should_load_sharding_stage1_model:
+            opt_state_dict = self.sharding_io.load_optimizer_state_with_reshard(
+                checkpoint, base_opt_name=OPTIMIZER_NAME
+            )
+        else:
+            optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+            path = os.path.join(checkpoint, optimizer_name)
+            if os.path.isfile(path):
+                opt_state_dict = paddlenlp_load(path, map_location="cpu")

-        if os.path.isfile(os.path.join(checkpoint, optimizer_name)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        if opt_state_dict is not None and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
-            self.optimizer.set_state_dict(paddlenlp_load(os.path.join(checkpoint, optimizer_name), map_location="cpu"))
+            self.optimizer.set_state_dict(opt_state_dict)

             self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
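For intuition about what the reshard load does (ShardingIO lives in paddlenlp/trainer/utils/sharding_io.py, imported above but not shown on this page, so this is an illustration of the idea rather than its implementation): a checkpoint written under one sharding degree is read shard file by shard file, the entries are merged by key, and each rank keeps only the slice it owns under the current sharding layout. A minimal toy sketch under those assumptions:

import os
import paddle

def toy_reshard_load(checkpoint_dir, shard_file_names, keys_owned_by_this_rank):
    """Toy reshard-on-load: merge every old shard file, then keep only the keys
    this rank owns under the new layout. Not the ShardingIO API."""
    merged = {}
    for fname in shard_file_names:  # e.g. one state file per old sharding rank
        shard = paddle.load(os.path.join(checkpoint_dir, fname), return_numpy=True)
        merged.update(shard)  # old shards hold disjoint entries, so update() is safe
    return {key: value for key, value in merged.items() if key in keys_owned_by_this_rank}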

paddlenlp/trainer/training_args.py

Lines changed: 48 additions & 1 deletion
@@ -503,6 +503,23 @@ class TrainingArguments:
             )
         },
     )
+    save_sharded_model: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When using sharding stage1 and save_sharded_model is True, each sharding rank saves only its own part of the model, which reduces the time needed to save the model."
+            )
+        },
+    )
+
+    load_sharded_model: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When using sharding stage1 and load_sharded_model is True, the sharded model is loaded. The sharded model is produced by saving with save_sharded_model=True."
+            )
+        },
+    )
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
@@ -1057,6 +1074,22 @@ def weight_name_suffix(self):
         else:
             return None

+    def sharded_name_suffix(self, shard_id=None):
+        if self.use_hybrid_parallel:
+            name = []
+            if self.tensor_parallel_degree > 1:
+                name.append(f"tp{self.tensor_parallel_rank:0>2d}")
+            if self.pipeline_parallel_degree > 1:
+                name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
+            if self.sharding_parallel_degree > 1:
+                if shard_id is None:
+                    shard_id = self.sharding_parallel_rank
+                assert isinstance(shard_id, int)
+                name.append(f"shard{shard_id:0>2d}")
+            return "_".join(name)
+        else:
+            return None
+
     @property
     def process_index(self):
         """
@@ -1115,7 +1148,9 @@ def should_save_model_state(self):
         if self.save_on_each_node:
             return self.local_process_index == 0
         else:
-            if self.use_hybrid_parallel:
+            if self.should_save_sharding_stage1_model:
+                return True
+            elif self.use_hybrid_parallel:
                 # save on dataset rank 0
                 return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
             else:
@@ -1128,6 +1163,18 @@ def _no_sync_in_gradient_accumulation(self):
         """
         return True

+    @property
+    def should_save_sharding_stage1_model(self):
+        return (
+            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model
+        )
+
+    @property
+    def should_load_sharding_stage1_model(self):
+        return (
+            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
+        )
+
     @contextlib.contextmanager
     def main_process_first(self, local=True, desc="work"):
         """
