Fix autotuner oom #2442
base: main
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the robustness of the autotuner by introducing Out-Of-Memory (OOM) error handling. It prevents the autotuner from crashing when GPU memory is exhausted during profiling, ensuring a more stable and resilient tuning process. Instead of failing, the system now clears the CUDA cache and reverts to a safe default configuration, allowing the application to continue operating without interruption.
📝 Walkthrough (CodeRabbit): Added robust exception handling around the autotuner's profiling loop so CUDA OOM errors no longer abort tuning.
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant Autotuner as Autotuner.choose_one
participant Runner as Runner.profile
participant CUDA
participant Cache
Caller->>Autotuner: request best (runner,tactic)
Autotuner->>Cache: lookup cached (runner,tactic)
alt cache miss / needs profiling
Autotuner->>Runner: profile(tactic_i)
Runner->>CUDA: allocate / run kernel
alt torch.cuda.OutOfMemoryError
CUDA-->>Runner: OOM error
Runner-->>Autotuner: raise OOM
Autotuner->>CUDA: torch.cuda.empty_cache()
Autotuner-->>Caller: return fallback (runner, tactic=-1)
else other Exception
Runner-->>Autotuner: exception
Autotuner->>Cache: record failed profiling (time=∞)
Autotuner->>Runner: continue with next tactic
else success
Runner-->>Autotuner: time_measured
Autotuner->>Cache: update best (runner,tactic)
Autotuner-->>Caller: return chosen (runner,tactic)
end
else cached
Cache-->>Autotuner: cached (runner,tactic)
Autotuner-->>Caller: return cached (runner,tactic)
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 inconclusive
Code Review
This pull request introduces graceful Out-Of-Memory (OOM) handling during the autotuning process, which improves the robustness of the system. The implementation correctly wraps the profiling loop in a try-except block for torch.cuda.OutOfMemoryError, clears the CUDA cache, and falls back to a default tactic. However, the cache update and statistics increment logic is currently positioned so that it executes even on cache hits, leading to inaccurate statistics and redundant cache writes. It should run only when a new optimal configuration is successfully profiled.
flashinfer/autotuner.py
Outdated
if runner_id is not None:
    # At least one valid (runner, tactic) pair is found
    cache_key = AutoTuner._get_cache_key(
        custom_op, runners[runner_id], p.get_opt_shapes(), tuning_config
    )
    # inspect call stack
    self.profiling_cache[cache_key] = (runner_id, tactic, p)
    self.stats.tuned_op_successful_configs[custom_op] = (
        self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1
    )
    logger.debug(
        f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
    )
This block of code, which updates the profiling cache and statistics, is currently placed outside the if not is_cache_hit: condition. This means it will execute even when a configuration is retrieved from the cache, leading to incorrect tuned_op_successful_configs counts and redundant cache updates. It should only execute when a new best runner/tactic is found after profiling. Please move this block inside the if not is_cache_hit: block, at the same indentation level as min_time = float("inf") (line 475).
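A minimal, self-contained sketch of the placement being suggested. The class and helper names (ToyTuner, profile, successful_configs) are hypothetical stand-ins rather than FlashInfer's actual code; the point is only that the cache write and the stats increment sit on the cache-miss path, after a new best tactic has been profiled.

import math


class ToyTuner:
    def __init__(self):
        self.profiling_cache = {}      # cache_key -> best tactic
        self.successful_configs = {}   # op name -> count of tuned configs

    def choose_one(self, op, cache_key, tactics, profile):
        cached = self.profiling_cache.get(cache_key)
        is_cache_hit = cached is not None

        if not is_cache_hit:
            best, min_time = None, math.inf
            for tactic in tactics:
                t = profile(tactic)
                if t < min_time:
                    best, min_time = tactic, t

            if best is not None:
                # Placed inside `if not is_cache_hit:` so cache hits neither
                # inflate the stats nor rewrite an existing cache entry.
                self.profiling_cache[cache_key] = best
                self.successful_configs[op] = (
                    self.successful_configs.get(op, 0) + 1
                )
                cached = best

        return cached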
📌 Description
Add graceful OOM handling during autotuning. When torch.cuda.OutOfMemoryError occurs, the autotuner now clears the CUDA cache and falls back to the default tactic (runners[0], -1) instead of crashing. The try-except block wraps the entire profiling loop, covering methods like _prepare_input_tensors() that could also cause OOM. An OOM raised in the inner profiling loop is re-raised so that it is caught by the outer exception handler.
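A rough sketch of that structure, with hypothetical helper names (prepare_input_tensors, profile_tactic) rather than the actual FlashInfer internals; only torch.cuda.OutOfMemoryError and torch.cuda.empty_cache() are real PyTorch APIs.

import math
import torch


def tune(runners, tactics, prepare_input_tensors, profile_tactic):
    # Hypothetical sketch: the outer try/except wraps everything that can
    # allocate GPU memory, including input preparation, not just the timing loop.
    try:
        tensors = prepare_input_tensors()          # may itself raise OOM
        best, best_time = (runners[0], -1), math.inf
        for runner in runners:
            for tactic in tactics:
                try:
                    t = profile_tactic(runner, tactic, tensors)
                except torch.cuda.OutOfMemoryError:
                    raise                          # defer OOM to the outer handler
                except Exception:
                    t = math.inf                   # record failed profiling, keep going
                if t < best_time:
                    best, best_time = (runner, tactic), t
        return best
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()                   # release cached GPU blocks
        return runners[0], -1                      # fall back to the default tactic

The key point is that an OOM anywhere inside the loop escapes the per-tactic handler and reaches the single outer handler, which cleans up and returns the safe default instead of propagating the error.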
🔍 Related Issues
Fixes #2357
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- All tests are passing (unittest, etc.).
Reviewer Notes
No tests added because OOM during autotuning is difficult to reliably reproduce in a test environment.