Tiny optimizations for moe #1717
Conversation
This reverts commit 9178d57.
Summary of Changes
Hello @fzyzcjy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on refining and optimizing the Mixture-of-Experts (MoE) implementation within the FlashInfer library. It introduces an experimental, warp-optimized CUDA kernel to potentially speed up element counting, alongside significant enhancements to the MoE benchmarking script. These enhancements include dynamic configuration generation, temporary expert masking for testing, and the integration of CUDA profiling hooks, all aimed at facilitating deeper performance analysis and identifying further optimization opportunities. Additionally, debug logging has been improved in key areas to aid development and understanding of execution flow.
Highlights
- **MoE Kernel Optimization Experimentation**: Introduced an experimental, warp-optimized CUDA kernel (`findTotalEltsLessThanTarget_v2`) for counting elements, alongside the existing binary search approach, suggesting exploration into performance improvements for Mixture-of-Experts (MoE) operations.
- **Benchmarking Enhancements**: Added temporary "HACK" code to the MoE benchmark script for masking selected experts and dynamically setting `tune_max_num_tokens`, along with integrating `cudaProfilerStart` and `cudaProfilerStop` calls for detailed performance analysis.
- **Dynamic Benchmark Configuration**: Modified the `bench_cutlass_fused_moe.py` script to dynamically generate `num_experts` based on `num_ranks` and updated the default `--num-tokens` argument, making benchmark configurations more flexible.
- **Debug Logging Additions**: Incorporated new debug print statements in `flashinfer/fused_moe/core.py` and `flashinfer/autotuner.py` to provide more visibility into tensor shapes and autotuner cache hits during execution.
- **Pyproject.toml Changes**: Commented out license-related fields in `pyproject.toml`, which might be a temporary measure or part of a larger licensing review.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point by creating a comment using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces several optimizations for MoE, but it is clearly a work-in-progress and contains a significant amount of temporary code for debugging and profiling. This includes `if 1:` blocks, "HACK" print statements, and hardcoded CUDA profiler calls in the benchmark file. The C++ code includes an experimental, non-generic function with `asm("trap;")`, which is unsafe. Furthermore, the project's license information has been commented out in `pyproject.toml`, which is a critical issue. All temporary and debug code must be removed, the unsafe C++ function needs to be made robust or removed, and the license information must be restored before this PR can be merged.
```cpp
__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
  constexpr int ARR_LENGTH_CONST = 128;
  if (arr_length != ARR_LENGTH_CONST) {
    asm("trap;");
  }
```
The function `findTotalEltsLessThanTarget_v2` has a hardcoded `ARR_LENGTH_CONST` and uses `asm("trap;")` if the input array length does not match. This makes the function non-generic and unsafe for general use, as it will cause a crash for any other input size. This experimental implementation should be made more robust or removed if it's not ready for production.
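For comparison, both the hardcoded length and the trap can be avoided with a strided warp count. The sketch below is a hypothetical alternative, not code from this PR; it assumes the kernel calls it from a full, converged warp (as the warp-optimized v2 appears to) and handles any `arr_length`:

```cuda
// Hypothetical generic variant of the warp-optimized count (not from this PR).
// Counts elements of sorted_indices that are strictly less than target.
// Assumes all 32 lanes of a converged warp call it with the same arguments.
template <typename T>
__device__ inline int64_t findTotalEltsLessThanTarget_warp(
    T const* sorted_indices, int64_t const arr_length, T const target) {
  int const lane = threadIdx.x % 32;
  int64_t local_count = 0;
  // Each lane inspects a strided subset, so any array length is handled.
  for (int64_t i = lane; i < arr_length; i += 32) {
    local_count += (sorted_indices[i] < target) ? 1 : 0;
  }
  // Butterfly shuffle reduction sums the per-lane counts; after the loop,
  // every lane holds the full total.
  for (int offset = 16; offset > 0; offset >>= 1) {
    local_count += __shfl_xor_sync(0xffffffffu, local_count, offset);
  }
  return local_count;
}
```

Because every lane returns the same total, such a variant could slot in wherever `findTotalEltsLessThanTarget` is dispatched; whether it actually beats the binary search depends on array length and occupancy, which is presumably what this PR's benchmarks are probing.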
```diff
 requires-python = ">=3.9,<4.0"
 authors = [{ name = "FlashInfer team" }]
-license = "Apache-2.0"
+#license = "Apache-2.0"
```
```diff
 urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" }
 dynamic = ["dependencies", "version"]
-license-files = ["LICENSE", "licenses/*"]
+#license-files = ["LICENSE", "licenses/*"]
```
```python
if 1:
    print("HACK: mask some selected_experts")
    selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999
```

```python
tune_max_num_tokens = batch_size
print(f"HACK: {tune_max_num_tokens=}")
```
```cpp
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
  return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);

  // return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);

  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
  // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
  // if (out_v1 != out_v2) {
  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
  //   asm("trap;");
  // }
  // return out_v1;
}
```
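If the v1/v2 cross-check in the commented-out block is worth keeping during development, it could sit behind a compile-time flag so release builds never reach the unconditional trap. A minimal sketch, reusing the `_v1`/`_v2` helpers from this diff; `MOE_DEBUG_CHECKS` is a hypothetical macro, not an existing FlashInfer flag:

```cuda
// Hypothetical debug-only dispatcher (not from this PR): cross-checks the
// experimental v2 kernel against the binary-search v1 only when
// MOE_DEBUG_CHECKS is defined, and compiles down to plain v1 otherwise.
template <typename T>
__device__ inline int64_t findTotalEltsLessThanTargetChecked(
    T const* sorted_indices, int64_t const arr_length, T const target) {
#ifdef MOE_DEBUG_CHECKS
  int64_t const out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
  int64_t const out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
  if (out_v1 != out_v2) {
    printf("findTotalEltsLessThanTarget mismatch: v1=%lld v2=%lld\n",
           (long long)out_v1, (long long)out_v2);
    __trap();  // abort the kernel, but only in debug builds
  }
  return out_v1;
#else
  return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
#endif
}
```

One caveat: as written in the PR, `_v2` itself traps whenever `arr_length != 128`, so the debug path would also need a length guard (or a generic variant like the one sketched above) before the comparison is meaningful.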
```python
]
hidden_states = x
hidden_states, input_sf = fp4_quantize(x, a1_gs)
print(f"{hidden_states.shape=}")
```
```python
else:
    # NOTE ADD
    logger.debug(
        f"[AutoTunner]: HACK ADD cache hit {custom_op=} {input_shapes=}"
    )
    return runner, tactic
```
`flashinfer/fused_moe/core.py` (Outdated)
```python
print(
    "hi flashinfer cutlass_fused_moe "
    f"{input.shape=} {input.dtype=} "
    f"{token_selected_experts.shape=}"
)
```
This reverts commit f31b592.
This reverts commit 348a536.
…t. Restore mm_fp4 API behavior (flashinfer-ai#1706)". This reverts commit e8f5460.
This reverts commit bc42393. (Conflicts: tests/test_mm_fp4.py)
This reverts commit d83a3cb.
📌 Description
Has speedup but still WIP, need to sleep now.

EDIT: some of the speedup is in the commit history; it does speed things up but introduces complexity, so I reverted those commits.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- All tests are passing (`unittest`, etc.).

Reviewer Notes