Commit c782976

hotfix entropy calculation (#6364)

1 parent 3d9dd34
1 file changed (+13 −19 lines)

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 19 deletions
@@ -263,6 +263,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
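The first hunk slices the rollout-wide old_action_log_probs tensor with the same window as the other per-sample tensors, so the old log-probs stay row-aligned with their micro-batch. A toy sketch of the pattern (names here are illustrative, not the real coati variables):

    import torch

    # Every per-sample tensor is sliced with the same window so rows stay aligned.
    batch = {"input_ids": torch.arange(8).view(8, 1), "old_logp": torch.randn(8, 1)}
    train_microbatch_size = 4
    for start in range(0, 8, train_microbatch_size):
        ids = batch["input_ids"][start : start + train_microbatch_size]
        old = batch["old_logp"][start : start + train_microbatch_size]
        assert ids.shape[0] == old.shape[0]  # aligned slices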
@@ -319,17 +322,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        policy_model_logits.copy_(action_logits)
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
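This hunk moves the entropy computation inside _criterion, working directly on the live action_logits instead of copying them into a preallocated buffer. A minimal sketch of the per-sequence math, assuming entropy_from_logits returns the per-token entropy of the categorical distribution over the vocabulary (a stand-in implementation, not the coati one):

    import torch
    import torch.nn.functional as F

    def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
        # H = -sum_v p_v * log p_v per token; (batch, seq, vocab) -> (batch, seq)
        logp = F.log_softmax(logits, dim=-1)
        return -(logp.exp() * logp).sum(-1)

    logits = torch.randn(2, 5, 11)   # (batch, num_action tokens, vocab); shapes made up
    action_mask = torch.ones(2, 5)
    per_token_h = entropy_from_logits_sketch(logits)
    # Masked mean over action tokens, mirroring the hunk above.
    per_seq_h = (per_token_h * action_mask).sum(-1) / action_mask.sum(-1)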
@@ -352,7 +360,7 @@ def _criterion(outputs, inputs):
 
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
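This is the core of the fix in the pipeline-parallel branch: the loss was previously called with action_log_probs as both the new and the old log-probs, so the importance ratio exp(new − old) was identically 1 and clipping never engaged. A sketch of why the two arguments must differ, using a generic clipped surrogate rather than the exact coati policy_loss_fn (which also takes per-token KL and masks):

    import torch

    def clipped_surrogate(logp_new, logp_old, advantages, eps=0.2):
        # Generic PPO-style clipped objective.
        ratio = torch.exp(logp_new - logp_old)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
        return -torch.min(unclipped, clipped).mean()

    logp = torch.randn(4)
    # Passing the same tensor twice: ratio == 1 everywhere, clipping is a no-op.
    assert torch.allclose(torch.exp(logp - logp), torch.ones(4))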
@@ -376,20 +384,6 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                    mini_batch_entropies.append(
-                        all_reduce_mean(
-                            (
-                                (
-                                    (
-                                        entropy_from_logits(policy_model_logits[:, -num_action:])
-                                        * action_mask_forward_micro_batch
-                                    ).sum(-1)
-                                )
-                                / action_mask_forward_micro_batch.sum(-1)
-                            ).detach(),
-                            self.plugin,
-                        )
-                    )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
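The deleted block computed entropy after the fact from policy_model_logits, a buffer created with torch.empty_like(input_ids_forward_micro_batch). Such a buffer cannot actually hold logits: empty_like inherits the integer dtype and the (batch, seq) shape of input_ids, with no vocab dimension, which is presumably why the computation was moved into _criterion where the real logits exist. A small illustration of the shape/dtype mismatch (shapes are made up):

    import torch

    input_ids = torch.randint(0, 8, (2, 4))   # (batch, seq), int64
    logits = torch.randn(2, 4, 8)             # (batch, seq, vocab), float32

    buf = torch.empty_like(input_ids)
    print(buf.dtype, buf.shape)               # torch.int64 torch.Size([2, 4])
    # buf.copy_(logits)                       # would raise: (2, 4, 8) is not
    #                                         # broadcastable to (2, 4)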
@@ -428,7 +422,7 @@ def _criterion(outputs, inputs):
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
-                    old_action_log_probs,
+                    old_action_log_probs_micro_batch,
                     advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                     per_token_kl,
                     action_mask_forward_micro_batch,
@@ -468,7 +462,7 @@ def _criterion(outputs, inputs):
             ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+            entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
             self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
             self.accum_entropy.add_(entropy.data)
             if self.policy_loss_fn.beta > 0:
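The final hunk replaces the per-micro-batch all_reduce_mean calls removed above with a single cross-rank reduction of the averaged entropy, so the logged value is no longer a rank-local mean. A rough sketch of the semantics such a helper usually has (an assumption about the coati all_reduce_mean, which also threads through a plugin handle):

    import torch
    import torch.distributed as dist

    def all_reduce_mean_sketch(t: torch.Tensor) -> torch.Tensor:
        # Sum across all data-parallel ranks, then divide by world size.
        if dist.is_initialized():
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            t = t / dist.get_world_size()
        return t

Reducing once over the concatenated mini-batch entropies is also cheaper than reducing inside the micro-batch loop, since it issues one collective per step instead of one per micro-batch.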
