diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4ffc63001..7a399e4f8 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -160,6 +160,10 @@ def __post_init__(self):
         }
         os.environ.update(env)
 
+        # Compile the loss function once so training steps reuse the compiled version
+        logger.info("Compiling loss")
+        self.loss = torch.compile(self.loss)
+
     @endpoint
     async def setup(self):
         # TODO: update ForgeEngine to not use ForgeJobConfig
diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py
index 2eca1fdd1..a65b86e96 100644
--- a/src/forge/util/ops.py
+++ b/src/forge/util/ops.py
@@ -56,6 +56,7 @@ def compute_logprobs(
 ) -> torch.Tensor:
     """
     Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability.
 
     Args:
         logits (`torch.Tensor`):
@@ -65,10 +66,23 @@ def compute_logprobs(
         temperature (`float`, *optional*, defaults to 1.0):
             The temperature value for scaling logits before computing log probabilities.
 
+    Returns:
+        logprobs: [batch, seq_len] log probabilities for each token.
     """
     # Ignore the last token from logits because it predicts the next token (-1)
     # And align logits with the input tokens length.
     logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
     scaled_logits = logits / temperature
-    logprobs = selective_log_softmax(scaled_logits, input_ids)
-    return logprobs
+
+    # Cast up to fp32 for numerical stability
+    scaled_logits_fp32 = scaled_logits.float()
+
+    # Per-token log probs: F.cross_entropy returns the negative log-likelihood, so negate it
+    batch_size, seq_len, vocab_size = scaled_logits_fp32.shape
+    logprobs = -F.cross_entropy(
+        scaled_logits_fp32.reshape(-1, vocab_size),
+        input_ids.reshape(-1).long(),
+        reduction="none",
+    )
+
+    return logprobs.reshape(batch_size, seq_len)
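
Note on the trainer.py change: torch.compile wraps the loss callable and compiles it lazily, on its first invocation with real inputs; the wrap in __post_init__ itself is cheap. A minimal sketch of the mechanism (loss_fn here is a hypothetical stand-in, not the trainer's actual loss):

    import torch
    import torch.nn.functional as F

    def loss_fn(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for the trainer's loss.
        return F.cross_entropy(logits, targets)

    compiled_loss = torch.compile(loss_fn)

    logits = torch.randn(8, 100, requires_grad=True)
    targets = torch.randint(0, 100, (8,))
    loss = compiled_loss(logits, targets)  # first call triggers compilation
    loss.backward()                        # gradients flow through the compiled fn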
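
Sanity check for the ops.py change: with integer targets, F.cross_entropy(..., reduction="none") returns the per-token negative log-softmax, so negating it should match gathering from log_softmax directly. The selective_log_softmax below is a reference reimplementation written for this check (the assumed behavior of the removed helper; the actual forge version may differ):

    import torch
    import torch.nn.functional as F

    def selective_log_softmax(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        # Assumed behavior of the removed helper: gather each target token's log-softmax value.
        return torch.gather(
            F.log_softmax(logits, dim=-1), dim=-1, index=ids.unsqueeze(-1)
        ).squeeze(-1)

    batch, seq, vocab = 2, 5, 32
    logits = torch.randn(batch, seq, vocab, dtype=torch.float32)
    ids = torch.randint(0, vocab, (batch, seq))

    old = selective_log_softmax(logits, ids)
    new = -F.cross_entropy(
        logits.reshape(-1, vocab), ids.reshape(-1), reduction="none"
    ).reshape(batch, seq)

    torch.testing.assert_close(old, new)  # equal up to floating-point rounding

The upcast matters when logits arrive in bf16 or fp16: computing log-softmax over a large vocabulary in a reduced-precision dtype loses precision in the log-sum-exp reduction, whereas performing it in fp32 keeps the per-token log probabilities stable.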