[Memory optm] loss using torch + compile #337
@@ -56,6 +56,7 @@ def compute_logprobs(
 ) -> torch.Tensor:
     """
     Computes the log probabilities of the input tokens given the model logits and temperature.
+    Always converts inputs to fp32 for numerical stability

     Args:
         logits (`torch.Tensor`):
@@ -65,10 +66,23 @@ def compute_logprobs(
         temperature (`float`, *optional*, defaults to 1.0):
             The temperature value for scaling logits before computing log probabilities.

     Returns:
         logprobs: [batch, seq_len] log probabilities for each token
     """
     # Ignore the last token from logits because it predicts the next token (-1)
     # And align logits with the input tokens length.
     logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
     scaled_logits = logits / temperature
-    logprobs = selective_log_softmax(scaled_logits, input_ids)
-    return logprobs
+
+    # Cast up to fp32 for numerical stability
Review comment: nit: I would change this to something like "ensure logits are in fp32", because they could already be in fp32 and there is no need for "casting up".
+    scaled_logits_fp32 = scaled_logits.float()
Review comment: Noob question: what's the dtype for
Reply: float becomes torch.float32
+
+    # get per-token log probs
+    batch_size, seq_len, vocab_size = scaled_logits_fp32.shape
+    log_probs = -F.cross_entropy(
+        scaled_logits_fp32.reshape(-1, vocab_size),
+        input_ids.reshape(-1).long(),
+        reduction="none",
+    )
+
+    return log_probs.reshape(batch_size, seq_len)
Review comment: Is there any circumstance under which this command would fail?
Reply: Can't think of one in our scenario, but if/when this happens, we can fix it.
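
For the dtype exchange above, a quick sanity check of what .float() does on a plain PyTorch tensor (illustration only, not part of this PR):

import torch

x = torch.randn(2, 3, dtype=torch.bfloat16)
print(x.float().dtype)  # torch.float32 -- .float() is shorthand for .to(torch.float32)

y = torch.randn(2, 3, dtype=torch.float32)
print(y.float() is y)   # True -- .to() returns self when the dtype already matches, so no copy is made

This also supports the "ensure logits are in fp32" nit: when the logits are already fp32, the cast is a no-op.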
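
And a self-contained sketch of the equivalence behind the change (the helper names below are hypothetical, and the gather-based reference is an assumption about what the removed selective_log_softmax computes): negating F.cross_entropy with reduction="none" should match gathering the target entries from a full log_softmax, and the fused loss is presumably what avoids materializing the full [batch, seq_len, vocab] log-prob tensor, especially under torch.compile.

import torch
import torch.nn.functional as F


def logprobs_via_cross_entropy(logits, input_ids, temperature=1.0):
    # Mirrors the new code path: align logits with input_ids, ensure fp32,
    # then reuse F.cross_entropy with reduction="none" and flip the sign.
    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
    scaled = (logits / temperature).float()
    b, s, v = scaled.shape
    nll = F.cross_entropy(scaled.reshape(-1, v), input_ids.reshape(-1).long(), reduction="none")
    return -nll.reshape(b, s)


def logprobs_via_gather(logits, input_ids, temperature=1.0):
    # Reference: full log_softmax followed by a gather of the target tokens.
    # This materializes the full [batch, seq_len, vocab] log-prob tensor.
    logits = logits[:, -input_ids.size(1) - 1 : -1, :].to(input_ids.device)
    scaled = (logits / temperature).float()
    return torch.log_softmax(scaled, dim=-1).gather(-1, input_ids.long().unsqueeze(-1)).squeeze(-1)


torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 32
logits = torch.randn(batch, seq_len + 1, vocab, dtype=torch.bfloat16)  # one extra position, as from a causal LM
input_ids = torch.randint(0, vocab, (batch, seq_len))
print(torch.allclose(logprobs_via_cross_entropy(logits, input_ids),
                     logprobs_via_gather(logits, input_ids), atol=1e-5))  # expected: True

If the two ever diverge, the first things to check would be the dtype and the slicing that aligns logits with input_ids, rather than the cross_entropy call itself.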