
Commit c640d37

Set KL divergence coefficient to zero in loss function
Changed beta parameter from 0.1 to 0.0 in simple_grpo_loss to remove the KL divergence penalty term from the loss.
1 parent 6186f9f commit c640d37

File tree

1 file changed (+1, -1 lines)


sandbox/grpo_language/main.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def simple_grpo_loss(
     ref_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
-    beta: float = 0.1,
+    beta: float = 0.0,
 ) -> torch.Tensor:
     logprobs: torch.Tensor = compute_logprobs(logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
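For context, below is a minimal sketch of how a beta coefficient typically gates the KL penalty in a GRPO-style loss. Only the function signature fragment and the kl line come from the diff above; the advantage-weighted term, the masking, and the helper names (pg, per_token_loss) are illustrative assumptions, not the repository's actual implementation.

    import torch

    def simple_grpo_loss_sketch(
        logprobs: torch.Tensor,      # per-token log-probs of the current policy
                                     # (in the real function: compute_logprobs(logits, response))
        ref_logprobs: torch.Tensor,  # per-token log-probs of the frozen reference policy
        advantages: torch.Tensor,    # advantages broadcast over the response tokens
        padding_mask: torch.Tensor,  # 1 for real tokens, 0 for padding
        beta: float = 0.0,           # KL coefficient; 0.0 disables the penalty, as in this commit
    ) -> torch.Tensor:
        # k3 estimator of KL(policy || reference), matching the line shown in the diff
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        # Illustrative advantage-weighted policy-gradient term (details of the real loss omitted)
        pg = -advantages * logprobs
        # With beta == 0.0 the KL term is multiplied by zero and drops out of the objective
        per_token_loss = pg + beta * kl
        # Average over non-padded tokens
        return (per_token_loss * padding_mask).sum() / padding_mask.sum().clamp(min=1)

In other words, the commit leaves the KL estimator in place but zeros its coefficient, so the penalty no longer contributes to the optimized loss.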
