Skip to content

Commit ddda79c

Browse files
committed
add entropy
1 parent dba0c0c commit ddda79c

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
346346
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
347347

348348
kl = []
349+
policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
349350

350351
def _criterion(outputs, inputs):
351352
action_logits = outputs.logits
@@ -425,6 +426,20 @@ def _criterion(outputs, inputs):
425426
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
426427
mean_kl.append(kl)
427428
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
429+
mini_batch_entropies.append(
430+
all_reduce_mean(
431+
(
432+
(
433+
(
434+
entropy_from_logits(policy_model_logits[:, -num_action:])
435+
* action_mask_forward_micro_batch
436+
).sum(-1)
437+
)
438+
/ action_mask_forward_micro_batch.sum(-1)
439+
).detach(),
440+
self.plugin,
441+
)
442+
)
428443
else:
429444
policy_model_logits = self.policy_model(
430445
input_ids=input_ids_forward_micro_batch,

0 commit comments

Comments
 (0)