Improve sampling benchmarks. #2374
Conversation
Summary of Changes

Hello @vincentzed, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request expands the FlashInfer benchmarking framework with dedicated routines for evaluating the performance of various sampling strategies: basic sampling from probability distributions, Top-P (nucleus) and Top-K sampling, their combined application, and utility functions for probability renormalization and logit masking. With this integration, users can assess the efficiency of these critical components of large language model inference workflows, providing useful data for optimization work.
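For readers new to these routines, here is a minimal PyTorch sketch of what top-k logit masking and sampling from probabilities compute conceptually. This is an illustrative reference only, not FlashInfer's kernel implementation; the function name topk_mask_logits_ref is hypothetical.

import torch

def topk_mask_logits_ref(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Threshold at the k-th largest logit per row; ties at the threshold
    # may keep a few extra tokens.
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth, float("-inf"))

logits = torch.randn(4, 32000)                         # batch of 4, vocab 32000
probs = torch.softmax(topk_mask_logits_ref(logits, 50), dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)  # "sampling from probs", conceptually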
Code Review
This pull request introduces comprehensive benchmark tests for various sampling routines in FlashInfer, which is a great addition for performance tracking. The changes include a new sampling.py routine file, updates to the main benchmark script and utilities to integrate the new tests, and documentation updates in the README.md.
The code is well-structured, but I have a few suggestions to improve maintainability and correctness:
- Refactor duplicated code in flashinfer_benchmark_utils.py for defining supported compute capabilities.
- Adhere to PEP 8 naming conventions for functions in the new sampling.py file.
- Add a reference check to the testTopPRenormProbs benchmark for correctness validation.
Details are in the specific comments. Overall, this is a solid contribution.
# SAMPLING - supported on all architectures
"sampling_from_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_p_sampling_from_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_k_sampling_from_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_k_top_p_sampling_from_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_k_renorm_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_p_renorm_probs": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
"top_k_mask_logits": {
    "7.5": ["cuda"],
    "8.0": ["cuda"],
    "8.6": ["cuda"],
    "8.9": ["cuda"],
    "9.0": ["cuda"],
    "10.0": ["cuda"],
    "10.3": ["cuda"],
    "12.0": ["cuda"],
},
}
There's a lot of code duplication here for defining the supported compute capabilities for sampling routines. All sampling routines share the same support matrix. To improve maintainability and reduce code duplication, you can define the support dictionary once and reuse it for all sampling routines. A dictionary comprehension can make this more concise.
# SAMPLING - supported on all architectures
**{
    routine: {
        "7.5": ["cuda"],
        "8.0": ["cuda"],
        "8.6": ["cuda"],
        "8.9": ["cuda"],
        "9.0": ["cuda"],
        "10.0": ["cuda"],
        "10.3": ["cuda"],
        "12.0": ["cuda"],
    }
    for routine in benchmark_apis["sampling"]
},
}
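A shared-constant variant of the suggestion above is another option if the inline comprehension feels opaque. The sketch below is hypothetical (the name _SAMPLING_SUPPORT is not in the PR); note that dict() makes a shallow copy, so the per-routine dicts don't alias each other, but the inner ["cuda"] lists are still shared and would need a deep copy if they may be mutated.

_SAMPLING_SUPPORT = {
    cc: ["cuda"]
    for cc in ("7.5", "8.0", "8.6", "8.9", "9.0", "10.0", "10.3", "12.0")
}
# Shallow-copy per routine so later per-routine edits don't alias.
**{routine: dict(_SAMPLING_SUPPORT) for routine in benchmark_apis["sampling"]},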
if args.routine == "sampling_from_probs":
    return testSamplingFromProbs(args)
if args.routine == "top_p_sampling_from_probs":
    return testTopPSamplingFromProbs(args)
if args.routine == "top_k_sampling_from_probs":
    return testTopKSamplingFromProbs(args)
if args.routine == "top_k_top_p_sampling_from_probs":
    return testTopKTopPSamplingFromProbs(args)
if args.routine == "top_k_renorm_probs":
    return testTopKRenormProbs(args)
if args.routine == "top_p_renorm_probs":
    return testTopPRenormProbs(args)
if args.routine == "top_k_mask_logits":
    return testTopKMaskLogits(args)
raise ValueError(f"Unsupported routine: {args.routine}")
The function names in this file (testSamplingFromProbs, testTopPSamplingFromProbs, etc.) do not follow the PEP 8 style guide, which recommends snake_case for function names. For consistency with the rest of the Python ecosystem and to improve readability, please rename these functions and their definitions. For example, testSamplingFromProbs should be test_sampling_from_probs.
Suggested change:

-if args.routine == "sampling_from_probs":
-    return testSamplingFromProbs(args)
-if args.routine == "top_p_sampling_from_probs":
-    return testTopPSamplingFromProbs(args)
-if args.routine == "top_k_sampling_from_probs":
-    return testTopKSamplingFromProbs(args)
-if args.routine == "top_k_top_p_sampling_from_probs":
-    return testTopKTopPSamplingFromProbs(args)
-if args.routine == "top_k_renorm_probs":
-    return testTopKRenormProbs(args)
-if args.routine == "top_p_renorm_probs":
-    return testTopPRenormProbs(args)
-if args.routine == "top_k_mask_logits":
-    return testTopKMaskLogits(args)
-raise ValueError(f"Unsupported routine: {args.routine}")
+if args.routine == "sampling_from_probs":
+    return test_sampling_from_probs(args)
+if args.routine == "top_p_sampling_from_probs":
+    return test_top_p_sampling_from_probs(args)
+if args.routine == "top_k_sampling_from_probs":
+    return test_top_k_sampling_from_probs(args)
+if args.routine == "top_k_top_p_sampling_from_probs":
+    return test_top_k_top_p_sampling_from_probs(args)
+if args.routine == "top_k_renorm_probs":
+    return test_top_k_renorm_probs(args)
+if args.routine == "top_p_renorm_probs":
+    return test_top_p_renorm_probs(args)
+if args.routine == "top_k_mask_logits":
+    return test_top_k_mask_logits(args)
+raise ValueError(f"Unsupported routine: {args.routine}")
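Independent of the naming fix, the if-chain itself could be collapsed into a dispatch table. A hedged sketch follows, using the snake_case names from the suggestion above; the table and wrapper names (ROUTINE_DISPATCH, run_sampling_test) are hypothetical, not part of this PR.

ROUTINE_DISPATCH = {
    "sampling_from_probs": test_sampling_from_probs,
    "top_p_sampling_from_probs": test_top_p_sampling_from_probs,
    "top_k_sampling_from_probs": test_top_k_sampling_from_probs,
    "top_k_top_p_sampling_from_probs": test_top_k_top_p_sampling_from_probs,
    "top_k_renorm_probs": test_top_k_renorm_probs,
    "top_p_renorm_probs": test_top_p_renorm_probs,
    "top_k_mask_logits": test_top_k_mask_logits,
}

def run_sampling_test(args):
    # Look up the routine once instead of walking an if-chain.
    try:
        return ROUTINE_DISPATCH[args.routine](args)
    except KeyError:
        raise ValueError(f"Unsupported routine: {args.routine}") from None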
benchmarks/routines/sampling.py (Outdated)
def testTopPRenormProbs(args):
    """Test top_p_renorm_probs API.

    This test:
    1. Generates random probability distributions
    2. Runs top_p_renorm_probs (renormalize by top-p thresholding)
    3. Measures performance metrics

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list: List of dictionaries containing performance results
    """
The testTopPRenormProbs function is missing a reference check (refcheck) to validate the correctness of the implementation. Other similar test functions in this file, like testTopKRenormProbs, include this check. Adding a reference implementation using PyTorch and comparing the results would increase confidence in the benchmark's correctness. You can find an example of a PyTorch reference implementation for top-p in tests/utils/test_sampling.py.
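For concreteness, a PyTorch reference along the lines the review suggests might look like the sketch below. This is illustrative only, assuming the standard top-p semantics (keep the smallest set of highest-probability tokens whose cumulative mass reaches p, renormalize the rest to zero); it is not the implementation in tests/utils/test_sampling.py, and kernel_out is a hypothetical placeholder for the FlashInfer output.

import torch

def top_p_renorm_probs_ref(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort descending and keep the smallest prefix whose cumulative mass >= p.
    sorted_probs, idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A sorted token survives if the mass *before* it is still below p.
    keep_sorted = ((cumulative - sorted_probs) < p).to(probs.dtype)
    mask = torch.zeros_like(probs).scatter(-1, idx, keep_sorted)
    renormed = probs * mask
    return renormed / renormed.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(4, 32000), dim=-1)
ref = top_p_renorm_probs_ref(probs, p=0.9)
# A refcheck could then be: torch.testing.assert_close(kernel_out, ref)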
Hi @vincentzed, would you mind checking the following files:
and seeing whether there are some components we can reuse?
vincentzed force-pushed from df66f58 to fddeca5 (intermediate commits: "style check", "minor style change", "more"; each Signed-off-by: vincentzed <[email protected]>)
📌 Description
Later, we will also add benchmarks for flashinfer.topk, since the only tests for it in the codebase are in tests/utils/test_topk.py and no performance data is tracked.

Motivation: sgl-project/sglang#17243 and other analysis to see whether sampling can be improved (it still takes a relatively trivial amount of time).
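As a rough baseline for that future work, one might time a plain torch.topk call with CUDA events. This is a hypothetical micro-benchmark sketch (requires a GPU), not the harness added in this PR:

import torch

probs = torch.softmax(torch.randn(128, 32000, device="cuda"), dim=-1)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):                      # warmup
    torch.topk(probs, k=50, dim=-1)
torch.cuda.synchronize()
start.record()
for _ in range(100):
    torch.topk(probs, k=50, dim=-1)
end.record()
torch.cuda.synchronize()
print(f"torch.topk: {start.elapsed_time(end) / 100:.3f} ms/iter")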
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes