Merged
4 changes: 4 additions & 0 deletions src/forge/actors/trainer.py
@@ -160,6 +160,10 @@ def __post_init__(self):
}
os.environ.update(env)

# compile loss
logger.info("Compiling loss")
self.loss = torch.compile(self.loss)
Member

Is there any circumstance under which this command would fail?

Contributor Author

Can't think of one in our scenario, but if/when it happens, we can fix it.
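For reference, a minimal sketch of how the compile step could be guarded if it ever does fail, falling back to the eager loss (the `maybe_compile` helper is illustrative and not part of this PR; note that `torch.compile` is lazy, so most failures would surface on the first call to the compiled loss rather than at wrap time):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def maybe_compile(loss_fn):
    """Try to compile the loss; fall back to the eager version if compilation fails."""
    try:
        return torch.compile(loss_fn)
    except Exception as exc:
        logger.warning("torch.compile failed (%s); using eager loss", exc)
        return loss_fn
```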


@endpoint
async def setup(self):
# TODO: update ForgeEngine to not use ForgeJobConfig
18 changes: 16 additions & 2 deletions src/forge/util/ops.py
@@ -56,6 +56,7 @@ def compute_logprobs(
) -> torch.Tensor:
"""
Computes the log probabilities of the input tokens given the model logits and temperature.
Always converts inputs to fp32 for numerical stability.

Args:
logits (`torch.Tensor`):
@@ -65,10 +66,23 @@
temperature (`float`, *optional*, defaults to 1.0):
The temperature value for scaling logits before computing log probabilities.

Returns:
logprobs: [batch, seq_len] log probabilities for each token
"""
# Ignore the last token from logits because it predicts the next token (-1)
# And align logits with the input tokens length.
logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
scaled_logits = logits / temperature
logprobs = selective_log_softmax(scaled_logits, input_ids)
return logprobs

# Convert to fp32 for numerical stability
scaled_logits_fp32 = scaled_logits.float()
Contributor

Noob question: what's the dtype for scaled_logits?

Contributor Author

`.float()` converts it to `torch.float32`.
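For illustration, a small standalone check of the dtypes involved (the bf16 input here is an assumption; the actual dtype of the incoming logits depends on the model/trainer config):

```python
import torch

logits = torch.randn(2, 8, 32, dtype=torch.bfloat16)  # stand-in for model logits
scaled_logits = logits / 1.0            # dividing by a Python float keeps bf16
print(scaled_logits.dtype)              # torch.bfloat16
print(scaled_logits.float().dtype)      # torch.float32 after the explicit cast
```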


# get per-token log probs
batch_size, seq_len, vocab_size = scaled_logits_fp32.shape
log_probs = -F.cross_entropy(
scaled_logits_fp32.reshape(-1, vocab_size),
input_ids.reshape(-1),
reduction="none",
)

return log_probs.reshape(batch_size, seq_len)
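As a sanity check on the new path, the negated unreduced cross-entropy should agree with the gathered log-softmax that `selective_log_softmax` computed before. A minimal standalone sketch (shapes and names are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16)
input_ids = torch.randint(vocab_size, (batch_size, seq_len))

scaled_fp32 = (logits / 1.0).float()

# New path: negated per-token cross-entropy.
via_ce = -F.cross_entropy(
    scaled_fp32.reshape(-1, vocab_size), input_ids.reshape(-1), reduction="none"
).reshape(batch_size, seq_len)

# Reference: log-softmax gathered at the target token ids.
via_gather = (
    torch.log_softmax(scaled_fp32, dim=-1)
    .gather(-1, input_ids.unsqueeze(-1))
    .squeeze(-1)
)

torch.testing.assert_close(via_ce, via_gather)
```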