Changed tests from logits to topk logprobs #745
Conversation
1e-1,  # 1e-1
1e-1,  # 1e-2
After removing all the logprobs comparisons, we can try setting it lower.
sglang only has atol and sets it to 5e-2 (decode_tolerance).
verl sets (atol, rtol) = (1e-2, 1e-5), but it compares the mean of all logprobs, not the top-k.
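For reference, a minimal sketch (with assumed tensor names, shapes, and values, not the actual test code) of how a single-tolerance check like sglang's maps onto torch.allclose:

```python
import torch

# Illustrative stand-ins for the expected and actual top-k logprob tensors;
# names, shapes, and values here are assumptions, not the real test fixtures.
expected = torch.randn(4, 20).log_softmax(dim=-1)
actual = expected + 0.02  # a uniform 0.02 offset in log-space

# torch.allclose passes when |actual - expected| <= atol + rtol * |expected|
# holds elementwise; atol=5e-2 with rtol=0 mirrors a single-tolerance check.
assert torch.allclose(actual, expected, atol=5e-2, rtol=0.0)
```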
It does not work with a lower tolerance. For gemma3, it passes with atol=1e-1 and rtol=1.
I tested this with fp32; it fails for most of the models where the old logits-checking logic passes.
Since we are comparing values in log-space, the total tolerance here is effectively a relative tolerance on the underlying probabilities.
Can we just check rtol?
Something like: tolerance = rtol * torch.abs(tensor2)
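Roughly, an rtol-only check along those lines could look like the sketch below; tensor1/tensor2 follow the naming in the comment and are otherwise illustrative:

```python
import torch

def rtol_only_close(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float) -> bool:
    # Drop the atol term entirely: accept only when the elementwise difference
    # stays within rtol * |tensor2|.
    tolerance = rtol * torch.abs(tensor2)
    return bool((torch.abs(tensor1 - tensor2) <= tolerance).all())
```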
The absolute difference of two logprobs, log A - log B, is the log of the ratio of the two probabilities, A / B, which means the whole tolerance (atol + rtol * torch.abs(expected)) effectively bounds the maximum relative difference we can accept.
I think that's also why sglang only has a single tolerance in their test.
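To make that concrete (a worked example added here, not part of the original discussion): a log-space tolerance of t bounds the probability ratio by exp(t), so atol=1e-1 allows the probabilities to differ by roughly 10%:

```python
import math

atol = 1e-1  # log-space tolerance used in this PR
# |log(A) - log(B)| <= atol  <=>  exp(-atol) <= A / B <= exp(atol)
lo, hi = math.exp(-atol), math.exp(atol)
print(f"allowed probability ratio: [{lo:.3f}, {hi:.3f}]")  # ~[0.905, 1.105]
```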
@Tcc0403 Can you have a look at the changes? I have tested them.
Check logprobs as well for consistency. I'm planning to rewrite the convergence tests, so just ignore the naming for now.
Gotcha!
Which logprobs do you take the mean over? I checked the verl implementation; they pick the per-token logprobs for the given labels.
I tried the top-20 logprobs and the tests pass for all the models! @Tcc0403
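A minimal sketch of one way to compare top-20 logprobs between two models; the helper names, the use of torch.testing.assert_close, and the default tolerances are assumptions for illustration, not the PR's actual implementation:

```python
import torch
import torch.nn.functional as F

def topk_logprobs(logits: torch.Tensor, k: int = 20) -> torch.Tensor:
    # Convert logits to log-probabilities and keep the k largest per position.
    logprobs = F.log_softmax(logits.float(), dim=-1)
    return logprobs.topk(k, dim=-1).values

def assert_topk_logprobs_close(ref_logits, test_logits, k=20, atol=1e-1, rtol=1e-1):
    # Compare the two models on the top-k slice of their distributions.
    torch.testing.assert_close(
        topk_logprobs(test_logits, k),
        topk_logprobs(ref_logits, k),
        atol=atol,
        rtol=rtol,
    )
```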
The tolerance for the gemma3 multimodal model had to be set high, as it otherwise does not pass the loss and topk_logprobs tests.
Yeah, I think we can compromise with 1e-1 before investigating the numerical issue further. Just make them all green first unless there's an obvious mismatch.
shimizust left a comment:
Thanks for making these changes!
Summary
Just testing out logprobs, as mentioned in #742.
It worked for the models where the logits-based test was failing.
Also tried setting a 1e-1 tolerance for qwen (previously 1), and it passed.
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence