optimize reshard #10925

Open · wants to merge 1 commit into base: develop

15 changes: 14 additions & 1 deletion llm/alignment/rl/run_rl.py
@@ -55,6 +55,7 @@ def process_args(model_args: ModelArgument, data_args: DataArgument, training_ar
training_args.max_src_len = data_args.max_prompt_len
training_args.actor_model_name_or_path = model_args.actor_model_name_or_path
training_args.max_length = data_args.max_length
# training_args.hybrid_parallel_topo_order = "mp_first"

if training_args.use_rm_server:
if model_args.reward_server is None:
@@ -319,11 +320,21 @@ def main():
max_sequence_length=data_args.max_length,
)

# import random
# random.seed(training_args.seed)
# paddle.seed(training_args.seed)
# print(f"Fu set seed:{training_args.seed}")
gather_in_micro_dp = training_args.gather_in_micro_dp # entry point for this change
use_export_only_rollout = training_args.use_export_only_rollout # entry point for this change
print(f"Fu gather_in_micro_dp:{gather_in_micro_dp}, use_export_only_rollout:{use_export_only_rollout}")
if (
training_args.rollout_tensor_parallel_degree != training_args.tensor_parallel_degree
or training_args.pipeline_parallel_degree > 1
):
reshard_controller = ReshardController(tensor_parallel_degree=training_args.rollout_tensor_parallel_degree)
reshard_controller = ReshardController(train_tensor_parallel_degree=training_args.tensor_parallel_degree,
infer_tensor_parallel_degree=training_args.rollout_tensor_parallel_degree,
gather_in_micro_dp=gather_in_micro_dp
)
else:
reshard_controller = None

@@ -401,6 +412,8 @@ def compute_metrics(eval_preds):
compute_metrics=compute_metrics, # TODO: only used for grpo (kk datasets)
generation_config=generation_config,
reshard_controller=reshard_controller,
gather_in_micro_dp=gather_in_micro_dp,
use_export_only_rollout=use_export_only_rollout,
)

# TODO(gongenlei) resume_from_checkpoint is not ready
1 change: 1 addition & 0 deletions paddlenlp/datasets/rlhf_datasets/protocol.py
@@ -1158,3 +1158,4 @@ def split_batch_into_micro_batches(self, batch_size, pad_token_id=0) -> List["Da
micro_batches.append(DataProto.from_single_dict(micro_batch))

return micro_batches

1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -585,6 +585,7 @@ def concat(tensor_list, axis=-1):
model_prefix = self.base_model_prefix + f".layers.{idx}"
# logger.info(f"set state for layer {idx}")

unfused_state_dict = {}
ln_scale = paddle.to_tensor(state_dict[f"{model_prefix}.input_layernorm.weight"]).cast(
self.transformer_block.ln_scales[idx].dtype
)
1 change: 1 addition & 0 deletions paddlenlp/rl/trainer/actor_trainer.py
@@ -61,6 +61,7 @@ def compute_logprob(self, batch: DataProto, key) -> DataProto:
Raises:
None.
"""
# print(f"Fu compute logprob using model: {type(self.model)} {id(self.model)}")
input_ids = batch.batch["input_ids"]
position_ids = batch.batch["position_ids"]
prompt = batch.batch.get("prompt", None)
159 changes: 150 additions & 9 deletions paddlenlp/rl/trainer/ppo_trainer.py
@@ -238,6 +238,7 @@ def __init__(
preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
generation_config: Optional[GenerationConfig] = None,
reshard_controller: Optional[ReshardController] = None,
**kwargs,
):
"""
Args:
@@ -318,6 +319,8 @@ def __init__(
) = self.init_train_num(self.train_dataloader)

args.max_steps = self.max_steps
self.gather_in_micro_dp = kwargs.get("gather_in_micro_dp", False)
self.use_export_only_rollout = kwargs.get("use_export_only_rollout", False)

self.reshard_controller = reshard_controller
trainer_agrs = {
@@ -334,6 +337,7 @@ def __init__(
"preprocess_logits_for_metrics": preprocess_logits_for_metrics,
}


self.actor_trainer = self.create_actor_trainer(
model=actor_model,
model_eval=actor_model_eval,
@@ -401,6 +405,7 @@ def __init__(
self.timers.log = types.MethodType(new_timer_log, self.timers)
self.generation_config = generation_config


def create_actor_trainer(
self,
model: Union[PretrainedModel, nn.Layer] = None,
@@ -432,6 +437,7 @@ def create_actor_trainer(
[None, lr_scheduler],
preprocess_logits_for_metrics,
reshard_controller,
gather_in_micro_dp=self.gather_in_micro_dp,
)
actor_trainer.set_eval_model(model_eval)
actor_trainer.timers = self.timers
@@ -719,7 +725,7 @@ def prediction_step(
inputs = self._prepare_inputs(inputs)
data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None)
inputs = data_group_split(inputs, group=data_trans_group)
with reload_and_offload_scope(self, self.actor_model, self.reference_model, self.actor_trainer):
with reload_and_offload_scope(self, self.actor_model, self.reference_model, self.actor_trainer, export_only_rollout=True):
with infer_guard(self.actor_trainer):
prompt_only_batch = DataProto.from_single_dict(
{
@@ -1349,12 +1355,23 @@ def train(

step = -1
for prompt_only_batch_dict in self.prompt_only_dataloader:
timer_batch = TimerScope(
self.timers,
RolloutStages.BATCH,
)
timer_batch.start()

prompt_only_batch: DataProto = DataProto.from_single_dict(prompt_only_batch_dict)
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
# step 1-1: rollout data with actor model (eval) and reward model
self.set_eval()
data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None)

# print(f"Fu data check [before data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'].shape}, {prompt_only_batch.batch['input_ids']._md5sum()}")
# print(f"Fu data check [before data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'][:,-120:-100]}")
prompt_only_batch = data_group_split(prompt_only_batch, group=data_trans_group)
# print(f"Fu data check [after data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'].shape}, {prompt_only_batch.batch['input_ids']._md5sum()}")
# print(f"Fu data check [after data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'][:,-120:-100]}")
eos_token_ids = llm_utils.get_eos_token_id(self.tokenizer, self.generation_config)
pad_token_id = self.tokenizer.pad_token_id
prompt_only_batch.meta_info = {
@@ -1371,7 +1388,7 @@
repeat_times=self.args.rollout_n, interleave=True
)
prompt_only_batch_expand.rename("raw_prompt_len", "raw_prompt_len_expand")

expand_prompt = prompt_only_batch_expand.batch["input_ids"]
per_device_rollout_batch_size = self.args.per_device_rollout_batch_size
cleanup_batches, indices, label_ids_batches = [], [], []
@@ -1381,9 +1398,9 @@
RolloutStages.ACTOR_MODEL_ENABLE_DISABLE,
minus_names=[RolloutStages.GENERATE],
)

timer_scope_actor_model.start()
with reload_and_offload_scope(self, self.actor_model):
with reload_and_offload_scope(self, self.actor_model, export_only_rollout=True):
timer_scope_rollout = TimerScope(self.timers, RolloutStages.GENERATE)
timer_scope_rollout.start()
with infer_guard(self.actor_trainer):
@@ -1392,9 +1409,25 @@
for i in range(0, total_batch_size, per_device_rollout_batch_size):
micro_prompt_batch = prompt_only_batch[i : i + per_device_rollout_batch_size]
# generate for multi batches and then disable FuseMT model

# hcg = fleet.get_hybrid_communicate_group()
# rollout_sharding_parallel_group = hcg.get_sharding_parallel_group()
# rollout_data_parallel_group = hcg.get_data_parallel_group()
# rollout_model_parallel_group = hcg.get_model_parallel_group()
# rollout_pipeline_parallel_group = hcg.get_pipeline_parallel_group()
# print(f"Fu rollout dp group:{rollout_data_parallel_group}, rollout tp group:{rollout_model_parallel_group}, rollout sdp group:{rollout_sharding_parallel_group}") #, rollout pp group:{rollout_pipeline_parallel_group}

# print(f"Fu data check [before generate_sequences] micro_prompt_batch: {micro_prompt_batch.batch['input_ids'].shape}, {micro_prompt_batch.batch['input_ids']._md5sum()}")
# print(f"Fu data check [before generate_sequences] micro_prompt_batch: {micro_prompt_batch.batch['input_ids'][:,300:320]}")
generated_batches: List[DataProto] = self.actor_trainer.generate_sequences(
micro_prompt_batch
)

# for idx,generated_batch in enumerate(generated_batches):
# print(f"Fu generated_batch {idx}, lr padding count:{count_lr_padding(generated_batch.batch['input_ids'])}")
# print(f"Fu data check [after generate_sequences] generated_batch: {generated_batch.batch['input_ids'].shape}, {generated_batch.batch['input_ids']._md5sum()}")
# print(f"Fu data check [after generate_sequences] generated_batch: {generated_batch.batch['input_ids'][:,300:320]}")

# NOTE(drownfish19): do process for each micro_batch, prepare for split mode
micro_ret = self.remove_pad_tokens_after_generate(generated_batches)
micro_cleanup_batches, micro_indices, micro_label_ids_batches = micro_ret
@@ -1451,20 +1484,24 @@ def train(
}
)

batch = data_group_merge(batch, group=data_trans_group)
batch = data_group_merge(batch, group=data_trans_group, pad_token_id=0) #self.tokenizer.pad_token_id
# print(f"Fu data check [after data group merge] batch: {batch.batch['input_ids'].shape}, {batch.batch['input_ids']._md5sum()} lr padding count:{count_lr_padding(batch.batch['input_ids'])}")
# print(f"Fu data check [after data group merge] batch: {batch.batch['input_ids'][:,300:320]}")


# step 2-2: balance batches based on batch tokens
if self.args.balance_batch:
batch = self._balance_batch(batch)

# step 2-3: compute logprob for rollout data
with self.autocast_smart_context_manager():
with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
with reload_and_offload_scope(self, self.reference_model):
with reload_and_offload_scope(self, self.reference_model, export_only_rollout=not self.use_export_only_rollout):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
ref_log_probs = self.reference_trainer.compute_logprob(batch, key="ref_log_probs")
batch = batch.union(ref_log_probs)

with reload_and_offload_scope(self, self.actor_model):
with reload_and_offload_scope(self, self.actor_model, export_only_rollout=not self.use_export_only_rollout):
with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
log_probs = self.actor_trainer.compute_logprob(batch, key="log_probs")
batch = batch.union(log_probs)
@@ -1479,6 +1516,7 @@ def train(
self,
self.critic_model if self.args.rl_algorithm == "ppo" else None,
self.reward_model if not self.args.use_rm_server and not self.args.use_rule_reward else None,
export_only_rollout=not self.use_export_only_rollout
):
with TimerScope(self.timers, RolloutStages.ROLLOUT_REWARD_VALUE):
reward_tensor = self.reward_trainer.compute_reward(
@@ -1603,7 +1641,7 @@ def train(
# step 3: train actor model and critic model with rollout data
self.set_train()
with TimerScope(self.timers, ActorStages.MODEL_ENABLE_DISABLE, minus_names=[ActorStages.RL_STEP]):
with reload_and_offload_scope(self, self.actor_model, self.actor_trainer.optimizer):
with reload_and_offload_scope(self, self.actor_model, self.actor_trainer.optimizer, export_only_rollout=not self.use_export_only_rollout):
with TimerScope(self.timers, ActorStages.RL_STEP):
# timer_info = {} # prepare for each micro_step
micro_batches = batch.split_batch_into_micro_batches(
@@ -1617,6 +1655,9 @@ def train(
with TimerScopeManualLabel(
self.timers, get_timer_label(ActorStages.MICRO_STEPS) + f"_{micro_step}"
):

# print(f"Fu data check [before train step {micro_step}] micro_batch: {micro_batch.batch['input_ids'].shape}, {micro_batch.batch['input_ids']._md5sum()} lr count padding:{count_lr_padding(micro_batch.batch['input_ids'])}")
# print(f"Fu data check [before train step {micro_step}] micro_batch: {micro_batch.batch['input_ids'][:,300:320]}")
rl_info = self.actor_trainer.update_actor(micro_batch)

paddle.device.cuda.empty_cache()
@@ -1642,12 +1683,14 @@ def train(
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

step += 1

timer_batch.stop()

self._print_timer()
self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=micro_batch)
paddle.device.cuda.empty_cache()
if self.control.should_epoch_stop or self.control.should_training_stop:
break


if step < 0:
logger.warning(
@@ -1935,3 +1978,101 @@ def compute_advantage_normalization(self, batch: DataProto):
batch.batch["reward_advantages"] = batch.batch["reward_advantages"] * batch.batch["eos_mask"]

return batch

def remove_pad_and_count_lr(tensor, pad_id=151643):
"""
Take a 2-D tensor, drop every pad_id element, and count the consecutive pad_id tokens on the left and right side of each row.

Args:
tensor (paddle.Tensor): shape = [batch_size, seq_len]
pad_id (int): padding token id

Returns:
filtered (paddle.Tensor): 1-D tensor of the valid elements with every pad_id removed
left_pad_counts (paddle.Tensor): shape=[batch_size], consecutive left-side pads per row
right_pad_counts (paddle.Tensor): shape=[batch_size], consecutive right-side pads per row
"""
batch_size, seq_len = tensor.shape
left_pad_counts = []
right_pad_counts = []

# converting to numpy makes counting consecutive pads easier
tensor_np = tensor.numpy()

for row in tensor_np:
# consecutive pads on the left
left_count = 0
for val in row:
if val == pad_id:
left_count += 1
else:
break

# consecutive pads on the right
right_count = 0
for val in row[::-1]:
if val == pad_id:
right_count += 1
else:
break

left_pad_counts.append(left_count)
right_pad_counts.append(right_count)

# convert the counts back to tensors
left_pad_counts = paddle.to_tensor(left_pad_counts, dtype='int64')
right_pad_counts = paddle.to_tensor(right_pad_counts, dtype='int64')

# filter out every pad_id element and return a 1-D tensor
mask = tensor != pad_id
filtered = paddle.masked_select(tensor, mask)

# print(f"tensor left padding:{left_pad_counts}, right padding:{right_pad_counts}")
return filtered, left_pad_counts, right_pad_counts
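
For illustration only (not part of this diff): a minimal usage sketch of remove_pad_and_count_lr as defined above, assuming paddle is importable and using the pad id 151643 that the helper defaults to; the toy values are hypothetical.

import paddle

# toy batch with left- and right-side padding (hypothetical values)
toy_batch = paddle.to_tensor(
    [
        [151643, 151643, 5, 6, 7, 151643],  # 2 left pads, 1 right pad
        [1, 2, 3, 151643, 151643, 151643],  # 0 left pads, 3 right pads
    ],
    dtype="int64",
)

filtered, left, right = remove_pad_and_count_lr(toy_batch, pad_id=151643)
# filtered -> Tensor([5, 6, 7, 1, 2, 3]): 1-D, every pad_id stripped
# left     -> Tensor([2, 0])
# right    -> Tensor([1, 3])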


def count_lr_padding(tensor, pad_id=151643):
"""
Count the consecutive pad_id tokens on the left and right side of each row.

Args:
tensor (paddle.Tensor): shape = [batch_size, seq_len]
pad_id (int): padding token id

Returns:
left_pad_counts (paddle.Tensor): shape=[batch_size], consecutive left-side pads per row
right_pad_counts (paddle.Tensor): shape=[batch_size], consecutive right-side pads per row
"""
batch_size, seq_len = tensor.shape
left_pad_counts = []
right_pad_counts = []

# converting to numpy makes counting consecutive pads easier
tensor_np = tensor.numpy()

for row in tensor_np:
# consecutive pads on the left
left_count = 0
for val in row:
if val == pad_id:
left_count += 1
else:
break

# consecutive pads on the right
right_count = 0
for val in row[::-1]:
if val == pad_id:
right_count += 1
else:
break

left_pad_counts.append(left_count)
right_pad_counts.append(right_count)

# convert the counts back to tensors
left_pad_counts = paddle.to_tensor(left_pad_counts, dtype='int64')
right_pad_counts = paddle.to_tensor(right_pad_counts, dtype='int64')

return left_pad_counts, right_pad_counts
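
Similarly, a minimal sketch (illustrative, not part of the diff) of how count_lr_padding could be used to inspect padding on a merged batch, in the spirit of the commented-out count_lr_padding debug prints above; merged_ids is a hypothetical name.

import paddle

merged_ids = paddle.to_tensor(
    [
        [151643, 11, 12, 13, 151643, 151643],  # 1 left pad, 2 right pads
        [21, 22, 23, 24, 25, 151643],          # 0 left pads, 1 right pad
    ],
    dtype="int64",
)

left_pads, right_pads = count_lr_padding(merged_ids, pad_id=151643)
# left_pads  -> Tensor([1, 0])
# right_pads -> Tensor([2, 1])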