diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4ffc63001..298c83489 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -47,6 +47,8 @@
 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
+
+from forge.data_models.loss_metrics import LossMetrics
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
@@ -179,7 +181,7 @@ async def setup(self):
 
     def forward_backward(
         self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
-    ) -> Tensor:
+    ) -> tuple[Tensor, LossMetrics]:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
 
@@ -236,12 +238,12 @@ def forward_backward(
             assert len(model_parts) == 1
             with self.engine.maybe_enable_amp:
                 logits = model_parts[0](**inputs)
-                loss = self.loss(logits, **targets)
+                loss, loss_metrics = self.loss(logits, **targets)
             # need to free to before bwd to avoid peaking memory
             del logits
             loss.backward()
 
-        return loss
+        return loss, loss_metrics
 
     @endpoint
     async def train_step(
@@ -264,7 +266,7 @@ async def train_step(
         #     self.model,
         #     self.data_parallel_size,
         # ) as grad_acc:
-        loss = self.forward_backward(local_inputs, local_targets)
+        loss, loss_metrics = self.forward_backward(local_inputs, local_targets)
         torch.distributed.all_reduce(loss)
         t.step("forward_backward")
 
@@ -287,11 +289,24 @@ async def train_step(
         record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
         record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
 
-        # TODO: Extract actual KL divergence and policy entropy from the loss computation
-        # These are placeholder values until the loss function exposes these metrics
-        # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
-        # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
-        # record_metric("rl_trainer/step/avg_policy_entropy", 0.0, Reduce.MEAN)
+        if loss_metrics.kl_divergence.sum().item() != 0:
+            record_metric(
+                "rl_trainer/step/avg_kl_divergence",
+                loss_metrics.kl_divergence,
+                Reduce.MEAN,
+            )
+            record_metric(
+                "rl_trainer/step/std_kl_divergence",
+                loss_metrics.kl_divergence,
+                Reduce.STD,
+            )
+
+        if loss_metrics.policy_entropy.sum().item() != 0:
+            record_metric(
+                "rl_trainer/step/avg_policy_entropy",
+                loss_metrics.policy_entropy,
+                Reduce.MEAN,
+            )
 
         self.step += 1
         self.engine.checkpointer.save(
diff --git a/src/forge/losses/grpo_loss.py b/src/forge/losses/grpo_loss.py
index 220367b47..8d4e85a03 100644
--- a/src/forge/losses/grpo_loss.py
+++ b/src/forge/losses/grpo_loss.py
@@ -7,6 +7,8 @@
 import torch
 from torch import nn
 
+from forge.data_models.loss_metrics import LossMetrics
+
 
 class SimpleGRPOLoss(nn.Module):
     """Simplified GRPO Loss for simplified single step updates
@@ -14,16 +16,27 @@
     https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
     """
 
     def __init__(self, beta: float = 0.1):
         super().__init__()
         self.beta = beta
 
-    def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
+    def forward(
+        self, logprobs, ref_logprobs, advantages, padding_mask
+    ) -> tuple[torch.Tensor, LossMetrics]:
         kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+
         per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
         per_token_loss = -(per_token_policy_loss - self.beta * kl)
         loss = (
             ((per_token_loss * padding_mask).sum(dim=1))
             / (padding_mask.sum(dim=1).clamp(min=1.0))
         ).mean()
-        return loss
+
+        # Per-sample masked mean of the KL over non-padded tokens, so that
+        # Reduce.MEAN in the trainer matches the masked-mean reduction of the loss.
+        masked_kl = (kl * padding_mask).sum(dim=1)
+        kl_per_sample = masked_kl / padding_mask.sum(dim=1).clamp(min=1.0)
+
+        return loss, LossMetrics(
+            kl_divergence=kl_per_sample, policy_entropy=torch.tensor(0)
+        )
diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index c2dedd530..c79130a78 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -7,6 +7,8 @@
 import torch
 from torch import nn
 
+from forge.data_models.loss_metrics import LossMetrics
+
 from forge.util.ops import selective_log_softmax
 
 
@@ -28,7 +30,7 @@ def __init__(self):
 
     def forward(
         self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
-    ):
+    ) -> tuple[torch.Tensor, LossMetrics]:
         trainer_log_probs = selective_log_softmax(trainer_logits, target_ids)
         target_mask = target_mask.detach()
         target_weights = target_weights
@@ -47,4 +49,6 @@ def forward(
         numerator = (-trainer_log_probs * weighted_advantages * target_mask).sum()
         denominator = target_mask_sum
 
-        return numerator / denominator
+        return numerator / denominator, LossMetrics(
+            kl_divergence=torch.tensor(0), policy_entropy=torch.tensor(0)
+        )