
Commit 68645cc

Cherry-pick useful PRs from fleety (#11051)
* [LLM] fix normal sl (#10609)
* [LLM] support disable monkey patch (#10617)
* [LLM] use safer grad sync method when enabling sp (#10714)
* [RL] disable aistudio download and fix qwen bug (#10819)
* [RL] fix qwen load when fuse is enabled and modify gate precision to fp32 (#10842)
* Fix a series of qwen3moe errors (justin) (#10818)
* [LLM] zcc support rng states (#10430) (#10485)
* [DLTP-85730] optimize save cost for zcc
* [LLM] fix moe using on tensor parallelism
* cherry-pick useful PRs from fleety and fix some typos

---------

Co-authored-by: aiyinyuedejustin <[email protected]>
1 parent e050fd8 commit 68645cc

File tree

11 files changed: +403, -148 lines


paddlenlp/trainer/trainer.py

Lines changed: 41 additions & 55 deletions
@@ -140,6 +140,7 @@
     DefaultFlowCallback,
     PrinterCallback,
     ProgressCallback,
+    SPGradSyncCallback,
     TrainerCallback,
     TrainerControl,
     TrainerState,
@@ -444,9 +445,8 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
             ), "should_save_sharding_stage1_model should be True when using zero cost checkpoint"
             assert (
                 ShardingOption.FULL_SHARD not in self.args.sharding
-            ), "FULL_SHARD is not supported when using zero cost checkpoint"
-            assert not self.args.save_tokenizer, "save_tokenizer is not supported when using zero cost checkpoint"
-            assert not self.args.save_rng_states, "save_rng_states is not supported when using zero cost checkpoint"
+            ), "FULL_SHARD is not supported when using flash save mode"
+            assert not self.args.save_tokenizer, "save_tokenizer is not supported when using flash save mode"

         # init attributes for zero cost checkpoint mode
         self.zcc_manager = None
@@ -2021,34 +2021,18 @@ def _load_rng_state(self, checkpoint):
         if checkpoint is None:
             return

-        # if use distributed training
-        if self.args.world_size > 1:
-            process_index = self.args.process_index
-            rng_file_list = [None for x in range(self.args.world_size)]
-            if self.args.should_save:
-                rng_file = os.path.join(checkpoint, f"rng_state_{self.args.world_size}.pth")
-                if os.path.isfile(rng_file):
-                    rng_file_list = paddle.load(rng_file, return_numpy=True)
-            paddle.distributed.broadcast_object_list(rng_file_list, src=0)
-            # if rng_file_list still empty, not log rng state.
-            if rng_file_list[0] is None:
-                logger.info(
-                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
-                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
-                )
-                return
-            else:
-                checkpoint_rng_state = rng_file_list[process_index]
-        else:
-            rng_file = os.path.join(checkpoint, "rng_state.pth")
-            if not os.path.isfile(rng_file):
-                logger.info(
-                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
-                    "fashion, reproducibility is not guaranteed."
-                )
-                return
+        rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
+        if not os.path.isfile(rng_file):
+            logger.info(
+                "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                "fashion, reproducibility is not guaranteed."
+            )
+            return

-            checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
+        checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
+        if checkpoint_rng_state.get("world_size", None) != self.args.world_size:
+            logger.warn("Cannot load rng states when changing world size of training job.")
+            return

         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])
@@ -2210,11 +2194,6 @@ def _wrap_model(self, model, training=True):
             else:
                 model, self.optimizer = decorated

-        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
-            )
-
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
@@ -2403,6 +2382,17 @@ def get_expected_keys(inputs, keys):
             ):
                 self.optimizer._set_broadcast_overlap(True, model)

+        # use callback for sp grad sync in case of unexpected behaviour (except sharding stage 2&3)
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            if ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.FULL_SHARD in self.args.sharding:
+                register_sequence_parallel_allreduce_hooks(
+                    unwrap_model(model),
+                    self.args.gradient_accumulation_steps,
+                    self.args.fuse_sequence_parallel_allreduce,
+                )
+            else:
+                self.add_callback(SPGradSyncCallback(model._layers))
+
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2739,28 +2729,24 @@ def _save_checkpoint(self, model, metrics=None):
         if self.args.should_save:
             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

-        # Save RNG state in non-distributed training
-        rng_states = {
-            "python": random.getstate(),
-            "numpy": np.random.get_state(),
-            "cuda": paddle.get_rng_state(),
-            "cpu": paddle.framework.core.default_cpu_generator().get_state(),
-        }
-        if self.args.use_hybrid_parallel:
-            rng_states[
-                "hybrid_parallel_rng_state_tracker"
-            ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
+        if self.args.save_rng_states:
+            # Save RNG state in non-distributed training
+            rng_states = {
+                "python": random.getstate(),
+                "numpy": np.random.get_state(),
+                "cuda": paddle.get_rng_state(),
+                "cpu": paddle.framework.core.default_cpu_generator().get_state(),
+                "world_size": self.args.world_size,
+            }
+            if self.args.use_hybrid_parallel:
+                rng_states[
+                    "hybrid_parallel_rng_state_tracker"
+                ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()

         if self.args.save_rng_states:
-            if self.args.world_size > 1:
-                rng_states_list = []
-                paddle.distributed.all_gather_object(rng_states_list, rng_states)
-                if self.args.should_save:
-                    os.makedirs(output_dir, exist_ok=True)
-                    paddle.save(rng_states_list, os.path.join(output_dir, f"rng_state_{self.args.world_size}.pth"))
-            else:
-                os.makedirs(output_dir, exist_ok=True)
-                paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+            rng_state_file = os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth")
+            os.makedirs(output_dir, exist_ok=True)
+            paddle.save(rng_states, rng_state_file)

         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
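
Reading note (not part of the commit): the trainer.py changes replace the single gathered RNG file with one snapshot per rank and guard resumption on the saved world size. A minimal sketch of that contract follows, using only calls that already appear in the diff; the helper name `can_restore_rng` and the checkpoint directory are hypothetical.

# Hedged sketch of the per-rank RNG snapshot contract implied by the diff above.
import os

import paddle
import paddle.distributed as dist


def can_restore_rng(checkpoint_dir, current_world_size):
    # each rank reads only the file it wrote in _save_checkpoint
    path = os.path.join(checkpoint_dir, f"rng_state_{dist.get_rank()}.pth")
    if not os.path.isfile(path):
        return False  # no snapshot for this rank; reproducibility not guaranteed
    snapshot = paddle.load(path, return_numpy=True)
    # the commit adds a "world_size" key so snapshots from a differently sized job are skipped
    return snapshot.get("world_size", None) == current_world_size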

paddlenlp/trainer/trainer_callback.py

Lines changed: 38 additions & 0 deletions
@@ -20,10 +20,18 @@
 """
 import dataclasses
 import json
+import time
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

 import numpy as np
+from paddle.distributed.fleet import fleet
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    fused_allreduce_gradients_with_group,
+)
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+    is_sequence_parallel_parameter,
+)
 from tqdm.auto import tqdm

 from paddlenlp.utils.log import logger
@@ -609,3 +617,33 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
         self.check_metric_value(args, state, control, metric_value)
         if self.early_stopping_patience_counter >= self.early_stopping_patience:
             control.should_training_stop = True
+
+
+class SPGradSyncCallback(TrainerCallback):
+    """
+    SPGradSyncCallback
+    Can only be used when sharding stage 2 is not enabled:
+    with sharding stage 2, gradients have already been cleared by the time `on_optimizer_begin` runs.
+    """
+
+    def __init__(self, model):
+        assert hasattr(fleet, "_hcg"), "must use MP when calling this Callback"
+        logger.info("using sp callback")
+        params = []
+        self.model = model
+        for n, p in model.named_parameters():
+            if is_sequence_parallel_parameter(p):
+                logger.info(f"register bw hook for:{n}")
+                params.append(p)
+
+        logger.info(f"#-sp-sync param:{len(params)}")
+        self._sp_params = params
+
+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        """on_optimizer_begin"""
+        if self._sp_params:
+            now = time.time()
+            mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+            fused_allreduce_gradients_with_group(self._sp_params, group=mp_group, scale=1.0)  # sum not mean
+            another_time = time.time()
+            logger.info(f"sync gradients takes {another_time - now} time")
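
For context (not from the commit itself): with tensor parallelism plus sequence parallelism, parameters marked as sequence-parallel accumulate partial gradients on each model-parallel rank, and this callback sums them right before the optimizer step instead of relying on backward hooks. A minimal sketch of the equivalent manual sync, assuming `layers` stands in for the fleet-wrapped model's `_layers` used in the trainer.py hunk:

# Hedged sketch: the sum-allreduce the callback performs in on_optimizer_begin.
from paddle.distributed.fleet import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    fused_allreduce_gradients_with_group,
)
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
    is_sequence_parallel_parameter,
)

sp_params = [p for _, p in layers.named_parameters() if is_sequence_parallel_parameter(p)]
if sp_params:
    mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
    # scale=1.0 keeps a plain sum across the model-parallel group, matching the callback
    fused_allreduce_gradients_with_group(sp_params, group=mp_group, scale=1.0)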

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 26 additions & 2 deletions
@@ -18,10 +18,12 @@
 import json
 import multiprocessing
 import os
+import random
 import time
 from collections import OrderedDict
 from enum import Enum

+import numpy as np
 import paddle
 import paddle.autograd as imperative_base
 import paddle.distributed as dist
@@ -414,10 +416,26 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
         self.maybe_update_zcc_worker(args, model, optimizer, state.global_step)
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
         save_infos = self._get_save_infos_based_on_steps(state, args, checkpoint_folder)
-        non_cached_objects = (lr_scheduler.state_dict(), copy.deepcopy(state))
+        non_cached_objects = (lr_scheduler.state_dict(), state, self.get_rng_states(args))
         self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects))
         self.runtime_timer.stop()

+    def get_rng_states(self, args):
+        if not args.save_rng_states:
+            return None
+        rng_states = {
+            "python": random.getstate(),
+            "numpy": np.random.get_state(),
+            "cuda": paddle.get_rng_state(),
+            "cpu": paddle.framework.core.default_cpu_generator().get_state(),
+            "world_size": args.world_size,
+        }
+        if args.use_hybrid_parallel:
+            rng_states[
+                "hybrid_parallel_rng_state_tracker"
+            ] = dist.fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
+        return rng_states
+
     def _get_save_infos_based_on_steps(self, state, args, checkpoint_folder):
         flash_device_checkpoint_dir = None
         persistent_checkpoint_dir = None
@@ -701,6 +719,7 @@ def __init__(
         # TODO(@gexiao): remove lr scheduler saves
         self.lr_scheduler = None
         self.trainer_state = None
+        self.rng_state = None

         # for dumping
         self.flash_device_save_dir = None
@@ -734,7 +753,7 @@ def process_prepare_task(self, prepares):
             return
         save_infos, non_cached_objects = prepares
         self.flash_device_save_dir, self.persistent_save_dir = save_infos
-        self.lr_scheduler, self.trainer_state = non_cached_objects
+        self.lr_scheduler, self.trainer_state, self.rng_state = non_cached_objects

     def process_offload_task(self, dump, global_step):
         """
@@ -897,6 +916,11 @@ def process_dump_task_impl(self, output_dir):
         if self.device_id == 0:
             self.trainer_state.save_to_json(trainer_state_name_path)

+        # Step2.5: save RNG State
+        if self.rng_state is not None:
+            rng_state_name_path = os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth")
+            paddle.save(self.rng_state, rng_state_name_path)
+
         # Step3: dump save signals
         saved_signal_path = os.path.join(output_dir, f"saved_signal_{self.global_rank}")
         with open(saved_signal_path, mode="w+") as f:
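
As an aside (not part of the commit): the zero cost checkpoint worker now writes the same per-rank RNG snapshot as the regular save path, so a resumed job can restore it through Trainer._load_rng_state. A small sketch of inspecting one of these files, with a hypothetical checkpoint path:

# Hedged sketch: peek at a per-rank RNG snapshot written by the ZCC worker.
import paddle

snapshot = paddle.load("checkpoint-100/rng_state_0.pth", return_numpy=True)  # hypothetical path
print(sorted(snapshot.keys()))
# per the diff, expected keys are "cpu", "cuda", "numpy", "python", "world_size",
# plus "hybrid_parallel_rng_state_tracker" when hybrid parallelism is enabled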

paddlenlp/transformers/moe_gate.py

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,7 @@ def top1gating(
         logits += self.gumbel_rsample(logits.shape)

         gates = self.gate_score_func(logits=logits)
+
         capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity)

         # Create a mask for 1st's expert per token
@@ -396,6 +397,7 @@ def top2gating(
         logits: paddle.Tensor,
     ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         # everything is in fp32 in this function
+
         gates = self.gate_score_func(logits=logits)

         # Create a mask for 1st's expert per token.
