
Commit ddcb722

[RL] logprob compute use the same method (#10596)
* fix logprob compute method
* update
1 parent cd2d2dc · commit ddcb722

File tree

3 files changed: +187 −17 lines changed


paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 20 additions & 8 deletions
@@ -485,16 +485,14 @@ def forward(
             kl_loss_coeff=self.kl_loss_coeff,
             loop_chunk_size=1024,
             response_start=response_start,
-            use_actor_fused_loss=self.entropy_coeff <= 0,  # currently only support kunbo's fused head loss
+            use_actor_fused_loss=True,  # currently only support kunbo's fused head loss
             temperature=self.temperature,
         )
         with paddle.no_grad():
             self.info_buffer["kl_loss"] = (
                 kl_loss.detach() / self.kl_loss_coeff if self.kl_loss_coeff > 0 else paddle.to_tensor([0.0])
             )
-            self.info_buffer["entropy_loss"] = (
-                entropy_loss.detach() / self.entropy_coeff if self.entropy_coeff > 0 else paddle.to_tensor([0.0])
-            )
+            self.info_buffer["entropy_loss"] = entropy_loss.detach()
             self.info_buffer["pure_policy_loss"] = (
                 pg_loss.detach() / self.pg_loss_coeff if self.pg_loss_coeff > 0 else paddle.to_tensor([0.0])
             )
@@ -716,6 +714,7 @@ def forward(
         clip_range_score: float,
         kl_loss_coeff: float,  # KL loss coefficient
         temperature: float,
+        print_entropy_loss: bool = True,
     ):
         """
         forward function of ActorFusedLoss
@@ -813,11 +812,11 @@ def forward(
             token_end_idx = min(i + loop_chunk_size, n_tokens)
             hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
             labels_chunk = labels[token_start_idx:token_end_idx]
-            old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx]
+            mask_chunk = loss_mask[token_start_idx:token_end_idx]
+            old_log_probs_chunk = old_log_probs[token_start_idx:token_end_idx] * mask_chunk
             if kl_loss_coeff > 0:
-                ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx]
+                ref_log_chunk = ref_log_probs[token_start_idx:token_end_idx] * mask_chunk
                 advantages_chunk = advantages[token_start_idx:token_end_idx]
-            mask_chunk = loss_mask[token_start_idx:token_end_idx]

             # Calculate the current logits_chunk, not fused linear
             logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
@@ -841,13 +840,14 @@ def forward(
             token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
             softmax_output_chunk = F.softmax(logits_chunk, axis=-1)

-            log_probs_chunk = -token_loss_chunk.squeeze(axis=-1)
+            log_probs_chunk = -token_loss_chunk.squeeze(axis=-1) * mask_chunk
             # calculate gradient, note sign
             grad_logits_chunk = labels_one_hot.astype("float32") - softmax_output_chunk
             grad_logits_chunk = grad_logits_chunk.astype(dtype)

             # ratio
             ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk)
+
             clipped_ratio_chunk = paddle.clip(
                 ratio_chunk, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
             )
@@ -892,6 +892,7 @@ def forward(
             if kl_loss_coeff > 0:
                 # [3] kl loss
                 delta_chunk = ref_log_chunk - log_probs_chunk
+
                 exp_delta_chunk = paddle.exp(delta_chunk)
                 kl_loss_estimate_chunk = exp_delta_chunk - delta_chunk - 1
                 kl_loss_clipped_chunk = (
@@ -912,6 +913,17 @@ def forward(
                 )
                 d_loss_d_logits_chunk += d_kl_log_probs_chunk.unsqueeze(-1) * d_log_probs_d_logits_chunk

+            if print_entropy_loss:
+                # [2] entropy loss
+                log_prob_chunk = paddle.log(paddle.clip(softmax_output_chunk, min=1e-12))
+                entropy_loss_chunk = -(softmax_output_chunk * log_prob_chunk).sum(axis=-1) * mask_chunk
+                # entropy_loss_chunk shape is [bs, seqlen, vocab_size // tensor_parallel_degree], do all_reduce sum here
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    paddle.distributed.all_reduce(
+                        entropy_loss_chunk, op=paddle.distributed.ReduceOp.SUM, group=model_parallel_group
+                    )
+                total_entropy_loss += entropy_loss_chunk.sum() / divisor
+
             # grads
             if grad_hidden_states is not None:
                 grad_hidden_states[token_start_idx:token_end_idx] = paddle.matmul(
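For context on the loss terms touched above: the new entropy branch computes per-token entropy −Σ_v p_v log p_v over the (possibly tensor-parallel-sharded) vocabulary, and the existing KL branch uses the k3 estimator exp(Δ) − Δ − 1 with Δ = log p_ref − log p_θ. The sketch below shows those two formulas on plain tensors; shapes and the mask-sum normalization are illustrative, not the repository's exact code (which divides by `divisor` and runs inside the fused, chunked loss).

```python
import paddle
import paddle.nn.functional as F

# Illustrative shapes: [n_tokens, vocab] logits, [n_tokens] per-token log-probs and mask.
logits = paddle.randn([8, 32000], dtype="float32")
log_probs = paddle.randn([8]) * 0.1 - 2.0       # current policy log p(y_t)
ref_log_probs = paddle.randn([8]) * 0.1 - 2.0   # reference policy log p(y_t)
mask = paddle.ones([8], dtype="float32")        # 1 for response tokens, 0 for padding

# Token entropy: -sum_v p_v * log p_v, masked. With tensor parallelism the per-rank
# partial sums would be all-reduced before this reduction, as in the diff above.
probs = F.softmax(logits, axis=-1)
entropy = -(probs * paddle.log(paddle.clip(probs, min=1e-12))).sum(axis=-1) * mask
entropy_loss = entropy.sum() / paddle.clip(mask.sum(), min=1.0)

# k3 KL estimate: exp(delta) - delta - 1, with delta = log p_ref - log p_theta.
delta = (ref_log_probs - log_probs) * mask
kl_estimate = (paddle.exp(delta) - delta - 1) * mask
kl_loss = kl_estimate.sum() / paddle.clip(mask.sum(), min=1.0)

print(entropy_loss.item(), kl_loss.item())
```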

paddlenlp/rl/trainer/actor_trainer.py

Lines changed: 157 additions & 0 deletions
@@ -16,6 +16,9 @@

 import numpy as np
 import paddle
+import paddle.nn.functional as F
+from paddle.distributed import fleet
+from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

 from ..models.ppo_model_utils import (
@@ -57,6 +60,13 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor
         Raises:
             None.
         """
+        if self.args.use_fused_head_and_loss_fn:
+            return self.compute_fused_logprob(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                **kwargs,
+            )
+
         log_probs_list = []
         batch_size, sequence_length = input_ids.shape
         per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
@@ -147,6 +157,153 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor

         return paddle.concat(log_probs_list, axis=0)

+    def compute_fused_logprob(
+        self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, loop_chunk_size=1024, **kwargs
+    ):
+        log_probs_list = []
+        batch_size, sequence_length = input_ids.shape
+        per_device_logprob_batch_size = self.args.per_device_logprob_batch_size
+        num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size
+
+        # Pipe model outputs a logits tensor with LMHead, while non-pipe model
+        # outputs a tuple with logits tensor as the only one element.
+        startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id)
+        response_start = kwargs["prompt"].shape[-1] - 1 if "prompt" in kwargs else 0
+
+        num_embeddings = self.model.config.vocab_size
+        tensor_parallel_degree = self.model.config.tensor_parallel_degree
+        tensor_parallel_output = self.model.config.tensor_parallel_output
+
+        for i in range(num_batches):
+            # Calculate the start and end indices for the current batch
+            start_index = i * per_device_logprob_batch_size
+            end_index = min(start_index + per_device_logprob_batch_size, batch_size)
+
+            # Extract the current batch
+            current_input_ids = input_ids[start_index:end_index]
+            current_startend_row_indices = (
+                startend_row_indices[start_index:end_index] if startend_row_indices is not None else None
+            )
+            current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None
+            current_labels = current_input_ids[:, response_start + 1 :]
+
+            if self.args.use_remove_padding:
+                from ..utils.bert_padding import prepare_flashmask_inputs
+
+                update_inputs = prepare_flashmask_inputs(
+                    current_input_ids,
+                    current_position_ids,
+                    self.tokenizer.pad_token_id,
+                    self.model.config.sequence_parallel,
+                    self.model.config.tensor_parallel_degree,
+                )
+                current_input_ids = update_inputs["input_ids"]
+                current_position_ids = update_inputs["position_ids"]
+                current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"]
+                indices = update_inputs["indices"]
+                raw_input_shape = update_inputs["raw_input_shape"]
+                pad_size = update_inputs["pad_size"]
+
+            # NOTE: for use_fused_head_and_loss_fn
+            self.model.training = True
+            hidden_states, lm_head_weight, lm_head_bias, transpose_y = self.model(
+                current_input_ids,
+                position_ids=current_position_ids,
+                attn_mask_startend_row_indices=current_startend_row_indices,
+            )
+            self.model.training = False
+
+            if self.args.use_remove_padding:
+                if pad_size > 0:
+                    hidden_states = hidden_states[:, :-pad_size]

+                from ..utils.bert_padding import pad_input
+
+                hidden_states = pad_input(
+                    hidden_states.squeeze(0), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1]
+                ).contiguous()
+
+            if self.args.use_fp32_compute and hidden_states.dtype != paddle.float32:
+                hidden_states = hidden_states.cast(paddle.float32)
+                lm_head_weight = lm_head_weight.cast(paddle.float32)
+                if lm_head_bias is not None:
+                    lm_head_bias = lm_head_bias.cast(paddle.float32)
+
+            # Recover
+            hidden_states = hidden_states[:, response_start:-1, :]
+            dtype = hidden_states.dtype
+            original_shape = hidden_states.shape
+            if tensor_parallel_degree > 1:
+                assert tensor_parallel_output, (
+                    "When tensor_parallel_degree > 1 and use_fused_head_and_loss_fn, "
+                    "tensor_parallel_output needs to be set to True."
+                )
+            # Parallel Configuration
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                tensor_parallel_degree = hcg.get_model_parallel_world_size()
+
+            # reshape
+            hidden_states = hidden_states.reshape([-1, original_shape[-1]])
+            labels = current_labels.reshape([-1])
+
+            n_tokens = hidden_states.shape[0]
+            n_classes = lm_head_weight.shape[0] if transpose_y else lm_head_weight.shape[1]
+
+            # convert dtype of weights and biases of lm_head
+            lm_head_weight_cast = lm_head_weight.astype(dtype)
+            if lm_head_bias is not None:
+                lm_head_bias_cast = lm_head_bias.astype(dtype)
+
+            # use indices to distinguish the devices.
+            if tensor_parallel_degree > 1 and tensor_parallel_output:
+                rank = hcg.get_model_parallel_rank()
+                per_part_size = num_embeddings // tensor_parallel_degree
+                indices = paddle.arange(
+                    rank * per_part_size,
+                    rank * per_part_size + n_classes,
+                    dtype=labels.dtype,
+                ).unsqueeze(0)
+            else:
+                indices = paddle.arange(num_embeddings, dtype=labels.dtype).unsqueeze(0)
+
+            log_prob_chunks = []
+            for ci in range(0, n_tokens, loop_chunk_size):
+                token_start_idx = ci
+                token_end_idx = min(ci + loop_chunk_size, n_tokens)
+                hidden_states_chunk = hidden_states[token_start_idx:token_end_idx]
+                labels_chunk = labels[token_start_idx:token_end_idx]
+
+                # Calculate the current logits_chunk, not fused linear
+                logits_chunk_cast = paddle.matmul(hidden_states_chunk, lm_head_weight_cast, transpose_y=transpose_y)
+                if lm_head_bias is not None:
+                    logits_chunk_cast += lm_head_bias_cast
+
+                logits_chunk = logits_chunk_cast.astype("float32")
+                logits_chunk = logits_chunk / self.args.temperature
+
+                # rewritten as cross entropy
+                if tensor_parallel_degree > 1 and tensor_parallel_output:
+                    token_loss_chunk = mp_ops._c_softmax_with_cross_entropy(
+                        logits_chunk,
+                        labels_chunk,
+                        group=model_parallel_group,
+                        return_softmax=False,
+                    )
+                else:
+                    token_loss_chunk = F.cross_entropy(logits_chunk, labels_chunk, reduction="none")
+                log_prob_chunk = -token_loss_chunk.squeeze(axis=-1)
+                log_prob_chunks.append(log_prob_chunk)
+
+            log_probs = paddle.concat(log_prob_chunks, axis=-1).reshape(original_shape[:-1])
+            log_probs_list.append(log_probs)
+
+            log_prob_chunks = None
+            paddle.device.cuda.empty_cache()
+
+        return paddle.concat(log_probs_list, axis=0)
+
     def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
         # inputs shared by policy and value trainer
         input_ids = rl_batch["input_ids"].contiguous()  # length: src+tgt
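The new compute_fused_logprob avoids materializing full [batch, seqlen, vocab] logits: tokens are processed in chunks, each chunk is projected through the LM head, and the per-token log-prob is recovered as the negative cross-entropy of the temperature-scaled logits, so rollout log-probs are computed by the same path as the fused training loss. A single-GPU sketch of that core loop follows; the toy shapes and the helper name `chunked_logprobs` are illustrative, and the tensor-parallel branch via mp_ops._c_softmax_with_cross_entropy and the remove-padding handling are omitted.

```python
import paddle
import paddle.nn.functional as F

def chunked_logprobs(hidden_states, lm_head_weight, labels, temperature=1.0, loop_chunk_size=1024):
    """Per-token log p(label) without materializing [n_tokens, vocab] logits at once.

    hidden_states: [n_tokens, hidden], lm_head_weight: [hidden, vocab], labels: [n_tokens].
    Single-device sketch only; the repository's version also handles tensor-parallel
    heads and padded/unpadded inputs.
    """
    chunks = []
    n_tokens = hidden_states.shape[0]
    for start in range(0, n_tokens, loop_chunk_size):
        end = min(start + loop_chunk_size, n_tokens)
        logits = paddle.matmul(hidden_states[start:end], lm_head_weight)  # [chunk, vocab]
        logits = logits.astype("float32") / temperature
        # cross_entropy returns -log p(label); negate it to get the log-prob.
        nll = F.cross_entropy(logits, labels[start:end], reduction="none")
        chunks.append((-nll).reshape([-1]))
    return paddle.concat(chunks, axis=0)

# Usage with hypothetical toy sizes:
h = paddle.randn([10, 64])
w = paddle.randn([64, 1000])
y = paddle.randint(0, 1000, shape=[10], dtype="int64")
print(chunked_logprobs(h, w, y, temperature=1.0).shape)  # [10]
```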

paddlenlp/rl/trainer/ppo_trainer.py

Lines changed: 10 additions & 9 deletions
@@ -1473,15 +1473,16 @@ def train(
             if self.args.balance_batch:
                 batch = self._balance_batch(batch)

-            # step 2-3: compute logprob for rollout data
-            with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
-                with reload_and_offload_scope(self, self.reference_model):
-                    with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
-                        batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)
-
-                with reload_and_offload_scope(self, self.actor_model):
-                    with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
-                        batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
+            with self.autocast_smart_context_manager():
+                # step 2-3: compute logprob for rollout data
+                with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
+                    with reload_and_offload_scope(self, self.reference_model):
+                        with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
+                            batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)
+
+                    with reload_and_offload_scope(self, self.actor_model):
+                        with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
+                            batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)

             # step 2-2: compute reward for rollout data
             with TimerScope(
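The ppo_trainer.py change only wraps both log-prob passes in the trainer's autocast context, so rollout-time old/ref log-probs are computed under the same mixed-precision policy as the training forward. A rough analogue using Paddle's built-in AMP context is sketched below; paddle.amp.auto_cast stands in for autocast_smart_context_manager (which may apply additional custom settings), and the timer/offload scopes are omitted.

```python
import paddle

model = paddle.nn.Linear(16, 16)
batch = paddle.randn([4, 16])

# Both forward passes share one autocast scope, mirroring the nesting in the diff above.
with paddle.amp.auto_cast():
    with paddle.no_grad():
        ref_log_probs = model(batch)   # stand-in for reference_trainer.compute_logprob(**batch)
        old_log_probs = model(batch)   # stand-in for actor_trainer.compute_logprob(**batch)

print(ref_log_probs.dtype, old_log_probs.dtype)
```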
