From 22e50d0153fef05bc6b107cdd4524ae6374cf281 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 7 Oct 2025 12:56:20 -0700 Subject: [PATCH 1/5] loss using torch --- src/forge/actors/trainer.py | 4 ++++ src/forge/util/ops.py | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4ffc63001..7a399e4f8 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -160,6 +160,10 @@ def __post_init__(self): } os.environ.update(env) + # compile loss + logger.info("Compiling loss") + self.loss = torch.compile(self.loss) + @endpoint async def setup(self): # TODO: update ForgeEngine to not use ForgeJobConfig diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index 2eca1fdd1..bc64b9fea 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -56,6 +56,7 @@ def compute_logprobs( ) -> torch.Tensor: """ Computes the log probabilities of the input tokens given the model logits and temperature. + Always converts inputs to fp32 for numerical stability Args: logits (`torch.Tensor`): @@ -65,10 +66,23 @@ def compute_logprobs( temperature (`float`, *optional*, defaults to 1.0): The temperature value for scaling logits before computing log probabilities. + Returns: + logprobs: [batch, seq_len] log probabilities for each token """ # Ignore the last token from logits because it predicts the next token (-1) # And align logits with the input tokens length. 
logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) scaled_logits = logits / temperature - logprobs = selective_log_softmax(scaled_logits, input_ids) - return logprobs + + # Convert to fp32 for numerical stability + scaled_logits_fp32 = scaled_logits.float() + + # get per-token log probs + batch_size, seq_len, vocab_size = scaled_logits_fp32.shape + log_probs = -F.cross_entropy( + scaled_logits_fp32.reshape(-1, vocab_size), + input_ids.reshape(-1), + reduction="none", + ) + + return log_probs.reshape(batch_size, seq_len) From 294dd293d5bb5ef0ba62dedf0a0ff113bb7d4e6c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 7 Oct 2025 15:53:56 -0700 Subject: [PATCH 2/5] add long to pass tests --- src/forge/util/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index bc64b9fea..14a2a8f28 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -81,7 +81,7 @@ def compute_logprobs( batch_size, seq_len, vocab_size = scaled_logits_fp32.shape log_probs = -F.cross_entropy( scaled_logits_fp32.reshape(-1, vocab_size), - input_ids.reshape(-1), + input_ids.reshape(-1).long(), reduction="none", ) From 9b6f3b5f3445a8914af3f63e281b804e4decf260 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 7 Oct 2025 20:28:03 -0400 Subject: [PATCH 3/5] Update src/forge/util/ops.py Co-authored-by: Jiyue Wang --- src/forge/util/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index 14a2a8f28..c0c55ad64 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -74,7 +74,7 @@ def compute_logprobs( logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device) scaled_logits = logits / temperature - # Convert to fp32 for numerical stability + # Cast up to fp32 for numerical stability scaled_logits_fp32 = scaled_logits.float() # get per-token log probs From 8723a5ac82f9b6aff4d3428a094e99ab87f168ac Mon Sep 17 00:00:00 
2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:39:52 -0400 Subject: [PATCH 4/5] Update src/forge/util/ops.py Co-authored-by: Joe Cummings --- src/forge/util/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index c0c55ad64..4f8d335c6 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -79,7 +79,7 @@ def compute_logprobs( # get per-token log probs batch_size, seq_len, vocab_size = scaled_logits_fp32.shape - log_probs = -F.cross_entropy( + logprobs = -F.cross_entropy( scaled_logits_fp32.reshape(-1, vocab_size), input_ids.reshape(-1).long(), reduction="none", From c3a4586770f33dce242fb3e2bd24e8ca52104e48 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 9 Oct 2025 07:08:03 -0700 Subject: [PATCH 5/5] fix return to use renamed logprobs variable --- src/forge/util/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py index 4f8d335c6..a65b86e96 100644 --- a/src/forge/util/ops.py +++ b/src/forge/util/ops.py @@ -85,4 +85,4 @@ def compute_logprobs( reduction="none", ) - return log_probs.reshape(batch_size, seq_len) + return logprobs.reshape(batch_size, seq_len)