
Commit 57e9210

hotfix entropy calculation (#6364)
1 parent 4cf5ce2 commit 57e9210

File tree: 1 file changed (+13, -19 lines)

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 19 deletions
@@ -250,6 +250,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
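
Note (not part of the commit): this hunk slices the rollout-time old_action_log_probs with exactly the same window used for input_ids and attention_mask, so each training micro-batch carries the old log-probs for its own rows. A minimal sketch of that invariant, with illustrative names and shapes (not the repo's code):

import torch

# Hypothetical shapes; only the shared slicing window matters here.
batch_size, seq_len, train_microbatch_size = 8, 16, 4
input_ids = torch.randint(0, 100, (batch_size, seq_len))
old_action_log_probs = torch.randn(batch_size, seq_len)   # computed once from the rollout policy

for forward_micro_batch_start in range(0, batch_size, train_microbatch_size):
    window = slice(forward_micro_batch_start, forward_micro_batch_start + train_microbatch_size)
    input_ids_mb = input_ids[window]
    old_logp_mb = old_action_log_probs[window]             # same rows as input_ids_mb
    assert input_ids_mb.size(0) == old_logp_mb.size(0)
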
@@ -306,17 +309,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        policy_model_logits.copy_(action_logits)
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
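
Note: with this hunk the per-token entropy is computed inside _criterion, directly from the micro-batch's live action_logits, restricted to the last num_action positions and masked/averaged with inputs["action_mask"]. A hedged sketch, assuming entropy_from_logits follows the usual categorical-entropy formula (the repo's helper may differ in detail):

import torch
import torch.nn.functional as F

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Entropy of the categorical distribution at each position:
    # H = logsumexp(logits) - sum(softmax(logits) * logits)
    probs = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (probs * logits).sum(-1)

# Illustrative shapes: [batch, prompt + response positions, vocab]
batch, num_action, vocab = 2, 5, 11
action_logits = torch.randn(batch, num_action + 3, vocab)
action_mask = torch.ones(batch, num_action)                # 1 where a response token exists

per_token_entropy = entropy_from_logits_sketch(action_logits[:, -num_action:])
per_sequence_entropy = (per_token_entropy * action_mask).sum(-1) / action_mask.sum(-1)
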
@@ -339,7 +347,7 @@ def _criterion(outputs, inputs):

                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
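
Note: this one-argument change is the core of the fix. Before it, self.policy_loss_fn received the current action_log_probs as both the new and the old log-probs, so the importance ratio exp(new - old) was identically 1 and the clipped surrogate never saw an off-policy correction. Passing inputs["old_action_log_probs"] (the rollout-time values sliced per micro-batch above) restores the ratio. A hedged illustration of a PPO/GRPO-style clipped loss, not the repo's PolicyLoss:

import torch

def clipped_policy_loss(new_logp, old_logp, advantages, action_mask, clip_eps=0.2):
    ratio = torch.exp(new_logp - old_logp)                  # importance ratio per token
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token = -torch.min(surr1, surr2)
    return (per_token * action_mask).sum() / action_mask.sum()

new_logp = torch.randn(2, 5)
adv = torch.randn(2, 5)
mask = torch.ones(2, 5)

# Passing the same tensor twice: ratio == 1 everywhere, clipping never engages.
degenerate = clipped_policy_loss(new_logp, new_logp, adv, mask)
# With genuine old log-probs the ratio deviates from 1 and clipping can bind.
proper = clipped_policy_loss(new_logp, torch.randn(2, 5), adv, mask)
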
@@ -363,20 +371,6 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                    mini_batch_entropies.append(
-                        all_reduce_mean(
-                            (
-                                (
-                                    (
-                                        entropy_from_logits(policy_model_logits[:, -num_action:])
-                                        * action_mask_forward_micro_batch
-                                    ).sum(-1)
-                                )
-                                / action_mask_forward_micro_batch.sum(-1)
-                            ).detach(),
-                            self.plugin,
-                        )
-                    )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
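
Note: the removed block computed the batch entropy after the pipeline forward, from a policy_model_logits buffer pre-allocated with torch.empty_like(input_ids_forward_micro_batch). torch.empty_like inherits the shape and integer dtype of input_ids, not the [batch, seq, vocab] float shape of the logits, so stashing the logits there for a later entropy pass appears unable to work as intended; the entropy is now taken inside _criterion (second hunk) while the real logits are in scope. A small, self-contained illustration of the PyTorch behavior in question (not the repo's code):

import torch

input_ids = torch.randint(0, 100, (2, 8))   # int64, shape [2, 8]
logits = torch.randn(2, 8, 32)               # float32, shape [2, 8, 32]
buffer = torch.empty_like(input_ids)         # int64, shape [2, 8], mirrors input_ids, not logits
try:
    buffer.copy_(logits)                     # source is not broadcastable to [2, 8]
except RuntimeError as err:
    print("copy_ rejected the logits:", err)
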
@@ -415,7 +409,7 @@ def _criterion(outputs, inputs):

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -455,7 +449,7 @@ def _criterion(outputs, inputs):
             ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+            entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
             self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
             self.accum_entropy.add_(entropy.data)
             if self.policy_loss_fn.beta > 0:
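
Note: the last hunk mirrors how loss and kl are aggregated: the per-micro-batch entropies collected inside _criterion are concatenated, reduced to one local scalar, and only then averaged across data-parallel ranks with all_reduce_mean, instead of leaving the final mean unreduced. A hedged sketch of the aggregation pattern, assuming all_reduce_mean(tensor, plugin) performs roughly the torch.distributed reduction below (names and process-group handling are illustrative):

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Average a tensor across all data-parallel ranks; no-op in a single-process run.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor = tensor / dist.get_world_size()
    return tensor

mini_batch_entropies = [torch.rand(4), torch.rand(4)]           # one entry per micro-batch
local_entropy = torch.cat(mini_batch_entropies, dim=0).mean()   # single local scalar
entropy = all_reduce_mean_sketch(local_entropy)                 # one reduction per optimizer step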
