Fix numerical stability of test_gemm_common.py #283
f1d24fe to b11b6fc
Relax tensor compare threshold
Signed-off-by: Chenjie Luo <[email protected]>
b11b6fc to e500c75
Walkthrough
Adds an autouse pytest fixture that sets deterministic seeds before each test, and relaxes the tolerance in the test_dynamic_gemm comparisons from atol/3 to atol/2 in the same test module.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/gpu/torch/quantization/backends/test_gemm_common.py (2)
32-36: Autouse seed fixture is good; consider removing redundant module-level seeding.
@pytest.fixture(autouse=True) ensures per-test determinism. You also call set_seed() at import time (Line 29). Double seeding is harmless but redundant and can confuse ordering. Prefer using the fixture only, or add a brief comment justifying the import-time call.
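A minimal sketch of the fixture-only approach suggested above. The body of the repository's set_seed helper is not shown in this thread, so seed_everything below is a hypothetical stand-in that seeds only Python's random module, and the seed value is an assumption; a real GPU test would presumably seed torch as well.

```python
import random

import pytest

SEED = 1234  # hypothetical value; the PR does not state which seed is used


def seed_everything(seed: int = SEED) -> None:
    """Hypothetical stand-in for the repo's set_seed helper (body not shown in the PR)."""
    random.seed(seed)
    # Assumption: the real helper likely also seeds torch and CUDA generators.


@pytest.fixture(autouse=True)
def _reset_seed():
    """Runs before every test in this module; tests do not need to request it."""
    seed_everything()


def draw_inputs(n: int = 4) -> list[float]:
    """Mimics a test body that generates random inputs."""
    return [random.random() for _ in range(n)]
```

With the seed reset in the fixture, each test sees the same random inputs regardless of execution order or of which other tests ran first, so an import-time seeding call adds nothing.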
266-272: Threshold relaxation looks fine; consider adding rtol and better failure diagnostics.
Keeping rtol consistent with earlier checks and printing diffs will make failures actionable while staying flaky-resistant.
- assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2)
- assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2)
+ assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2, rtol=rtol), (
+     f"dynamic: max|diff|={((output_dynamic_quant_gemm - output_dynamic_quant).abs()).amax()}"
+ )
+ assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2, rtol=rtol), (
+     f"calib: max|diff|={((output_calib_quant_gemm - output_calib_quant).abs()).amax()}"
+ )
  assert torch.allclose(
-     output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2
+     output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2, rtol=rtol
  )
- assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2)
+ assert torch.allclose(
+     output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2, rtol=rtol
+ )
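To illustrate why moving from atol/3 to atol/2 can turn a flaky comparison into a passing one, here is a dependency-free sketch of the elementwise rule that torch.allclose documents: |actual - expected| <= atol + rtol * |expected|, with rtol defaulting to 1e-5 when it is not passed. The atol value and the sample data are made up for illustration and are not taken from the test module.

```python
def is_close(actual, expected, atol, rtol=1e-5):
    """Pure-Python version of the rule |a - e| <= atol + rtol * |e|."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def max_abs_diff(actual, expected):
    """Largest elementwise deviation, useful in an assertion message."""
    return max(abs(a - e) for a, e in zip(actual, expected))


atol = 0.3  # hypothetical tolerance, not the value used in the test
reference = [1.0, -2.0, 4.0]
candidate = [1.0, -2.0, 4.12]  # one element deviates by about 0.12

strict_ok = is_close(candidate, reference, atol=atol / 3)   # bound is about 0.10
relaxed_ok = is_close(candidate, reference, atol=atol / 2)  # bound is about 0.15
message = f"max|diff|={max_abs_diff(candidate, reference):.3f}"
```

Here the deviation sits between the two bounds, so the stricter check fails and the relaxed one passes. Reporting the maximum difference in the assertion message, as the suggestion above does, shows how far a failure is from the threshold.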
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/gpu/torch/quantization/backends/test_gemm_common.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/quantization/backends/test_gemm_common.py (1)
tests/_test_utils/torch_misc.py (1): set_seed (33-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: code-quality
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
Relax tensor compare threshold
What does this PR do?
Fix unit test
Summary by CodeRabbit