Conversation

krammnic (Contributor) commented Oct 7, 2025

Resolves #298 (Log KL Divergence in GRPO Loss function)

meta-cla bot added the CLA Signed label on Oct 7, 2025
def forward_backward(
    self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
-) -> Tensor:
+) -> Tensor | LossMetrics:
Contributor

Shouldn't this be tuple[..,..]?

Contributor Author

It should!
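
For reference, a minimal sketch of the corrected annotation (assuming the intent is to return the loss tensor together with the metrics; Tensor and LossMetrics are the types used elsewhere in this diff):

def forward_backward(
    self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
) -> tuple[Tensor, LossMetrics]:
    ...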

krammnic (Contributor Author) commented Oct 7, 2025

@casteryh let's merge

self.beta = beta

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
Member

Can we log the KL divergence excluding padding tokens? We may have to move that op up in the loss function.

Contributor Author

Yep, good idea.

Contributor Author

Addressed.
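
As a rough sketch of the masking discussed above (assuming padding_mask is 1 for real tokens and 0 for padding; the helper name is illustrative, not the exact PR code):

import torch

def masked_kl_mean(logprobs, ref_logprobs, padding_mask):
    # Per-token k3 estimator of KL(pi || pi_ref), matching the
    # expression in the diff above.
    log_ratio = ref_logprobs - logprobs
    kl = torch.exp(log_ratio) - log_ratio - 1
    # Zero out padding positions before averaging so padding
    # tokens do not dilute the logged KL value.
    masked_kl = kl * padding_mask
    return masked_kl.sum() / padding_mask.sum().clamp(min=1)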

import torch
from torch import nn

from forge.data_models.loss_metrics import LossMetrics
Member

I'm not sure this is a fully fleshed-out data model we want to use.

For now, could we just define a loose type in this file and shove the metrics in that?

Contributor Author

I've done it with the data model because we might want to log other things from different losses in the future (margins from the DPO loss, for instance).
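
To illustrate the shape of that data model, a hedged sketch (the actual fields of forge.data_models.loss_metrics may differ):

from dataclasses import dataclass, field

import torch

@dataclass
class LossMetrics:
    # Scalar tensors logged alongside the loss; more fields can be
    # added later for other losses (e.g. DPO margins).
    kl_divergence: torch.Tensor | None = None
    extras: dict[str, torch.Tensor] = field(default_factory=dict)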
