-
Notifications
You must be signed in to change notification settings - Fork 436
[WIP] Add benchmark scripts to different moe gemms #1315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
fzyzcjy
wants to merge
28
commits into
flashinfer-ai:main
Choose a base branch
from
fzyzcjy:feat/dev_branch_20250724
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
a5390df
more
fzyzcjy 94c6d27
more
fzyzcjy 64b0032
more
fzyzcjy cfcc7ce
more
fzyzcjy 0e63dbe
more
fzyzcjy 72ec2c4
more
fzyzcjy 8caafde
more
fzyzcjy 5cbc7d0
more
fzyzcjy ab252c7
more
fzyzcjy 8629008
more
fzyzcjy 4624913
more
fzyzcjy 9de0387
more
fzyzcjy c89cf7c
more
fzyzcjy adf2cb1
more
fzyzcjy 57eda4a
more
fzyzcjy 8b606b8
more
fzyzcjy 7b78322
more
fzyzcjy 8c53c59
more
fzyzcjy c4c8a07
more
fzyzcjy a19273c
more
fzyzcjy a129862
more
fzyzcjy 3d37611
more
fzyzcjy 62eaa46
more
fzyzcjy dbecb4a
more
fzyzcjy 91863e4
more
fzyzcjy 3baee7e
more
fzyzcjy 57a18b6
more
fzyzcjy 806cf11
more
fzyzcjy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,10 @@ | |
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import json | ||
import os | ||
import sys | ||
import time | ||
|
||
import torch | ||
from torch.nn import functional as F | ||
|
@@ -21,6 +25,7 @@ | |
import flashinfer | ||
import flashinfer.fused_moe as fused_moe | ||
from flashinfer import fp4_quantize | ||
from flashinfer.testing.utils import bench_kineto | ||
|
||
BATCH_SIZES = [ | ||
1, | ||
|
@@ -35,7 +40,9 @@ | |
96, | ||
128, | ||
256, | ||
384, # NOTE ADD | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
512, | ||
768, # NOTE ADD | ||
1024, | ||
1536, | ||
2048, | ||
|
@@ -53,18 +60,40 @@ | |
FP8_DTYPE = torch.float8_e4m3fn | ||
|
||
test_configs = [ | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": 256, | ||
"top_k": 8, | ||
"intermediate_size": 256, | ||
}, | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": 32, | ||
"top_k": 8, | ||
"intermediate_size": 2048, | ||
}, | ||
# NOTE MODIFIED ADD | ||
*[ | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": num_experts, | ||
"top_k": 8, | ||
"intermediate_size": 2048, | ||
} | ||
for num_experts in [ | ||
288 // 1, | ||
288 // 2, | ||
288 // 4, | ||
288 // 8, | ||
288 // 16, | ||
288 // 32, | ||
# TODO support | ||
# 288 // 48, | ||
# 288 // 72, | ||
] | ||
], | ||
|
||
# --- old --- | ||
# { | ||
# "hidden_size": 7168, | ||
# "num_experts": 256, | ||
# "top_k": 8, | ||
# "intermediate_size": 256, | ||
# }, | ||
# { | ||
# "hidden_size": 7168, | ||
# "num_experts": 32, | ||
# "top_k": 8, | ||
# "intermediate_size": 2048, | ||
# }, | ||
] | ||
|
||
|
||
|
@@ -182,7 +211,22 @@ def bench_cutlass_fused_moe( | |
input_sf=input_sf, | ||
output=flash_output, | ||
) | ||
ms = do_bench( | ||
# NOTE MODIFIED | ||
# ms = do_bench( | ||
# lambda: fused_moe.cutlass_fused_moe( | ||
# hidden_states, | ||
# selected_experts.to(torch.int), | ||
# routing_weights, | ||
# w1_q.contiguous().view(torch.long), | ||
# w2_q.contiguous().view(torch.long), | ||
# otype, | ||
# quant_scales=quant_scales, | ||
# input_sf=input_sf, | ||
# output=flash_output, | ||
# ) | ||
# ) | ||
trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR") | ||
[time_gemm1, time_gemm2] = bench_kineto( | ||
lambda: fused_moe.cutlass_fused_moe( | ||
hidden_states, | ||
selected_experts.to(torch.int), | ||
|
@@ -193,12 +237,25 @@ def bench_cutlass_fused_moe( | |
quant_scales=quant_scales, | ||
input_sf=input_sf, | ||
output=flash_output, | ||
) | ||
), | ||
kernel_names="cutlass13device_kernelINS_4gemm6kernel", | ||
num_kernels_per_period=2, | ||
trace_path=f"{trace_dir}/{time.time()}.trace.json.gz" if trace_dir else None, | ||
) | ||
print( | ||
f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" | ||
) | ||
print(f"execution time: {ms}ms") | ||
|
||
# NOTE MODIFIED | ||
print(f"MAIN_OUTPUT=" + json.dumps(dict( | ||
batch_size=batch_size, | ||
num_experts=num_experts, | ||
top_k=top_k, | ||
intermediate_size=intermediate_size, | ||
time_gemm1_us=time_gemm1 * 1e6, | ||
time_gemm2_us=time_gemm2 * 1e6, | ||
))) | ||
Comment on lines
+247
to
+254
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# print( | ||
# f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" | ||
# ) | ||
# print(f"execution time: {ms}ms") | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider removing the unused
sys
andtime
imports to maintain clean and readable code.