Tiny optimizations for moe #1717
Open
fzyzcjy wants to merge 106 commits into flashinfer-ai:main from fzyzcjy:feat/speedup_moe
Commits (106, all by fzyzcjy); the diff below shows changes from all commits:
498027c  more
8b4c44b  more
db646ac  more
9c5e3ca  temp 4gpu
fd64748  more
145efc4  more
ed240ab  more
9178d57  temp
27d9053  Revert "temp"
d39f3ec  tune_max_num_tokens 16k -> 32k
2890f7e  more
dde43f6  more
671d1cd  more
0158022  fix instasll err
cc512e5  Merge branch 'main-upstream' into feat/bench_cutlass_moe
0bd7a7f  Merge branch 'feat/hack_license' into feat/bench_cutlass_moe
20da055  more
99f5ff2  more
622f8ae  more
eb55ef1  hack: mask some selected experts
723fea7  fix tune_max_num_tokens
ec9cab1  hack findTotalEltsLessThanTarget
29d4170  more
20a361b  more
d66ef25  more
f31b592  writeSF
480057d  pragma unroll
742521e  Revert "writeSF"
5cb9936  Revert "pragma unroll"
5d7e4d3  redo pragma unroll
d29f4f9  make hidden size const for everywhere
8543c83  inter_size constexpr for activation
efb5214  unroll permute copy
920c1ac  change unroll
b807967  fix wrong return
06d5fc9  simp input_sf
93f8690  try change order
18676b4  hack (should revert): temp rm padding
977b8ae  Revert "hack (should revert): temp rm padding"
ac807d7  prefetch unpermuted_row
ad285aa  temp hack: EXPAND_THREADS_PER_BLOCK 256->128
865e758  temp hack: EXPAND_THREADS_PER_BLOCK 256->32
ae5376f  Revert "temp hack: EXPAND_THREADS_PER_BLOCK 256->32"
c3d3e9f  temp hack: EXPAND_THREADS_PER_BLOCK=128 + blocks=x2
0910742  bench kineto
b15fa24  chore: rm log
516bf70  chore: log
630b482  chore: more tests
54757d7  revert kernel to 09:10
ab88987  chore bench
704c1c5  hack: only enable "thread/=2, block*=2"
ed6178c  hack: 64thread
bc48740  Revert "hack: 64thread"
e5b01db  unroll topk in unpermute kernel
4106c12  unpermute use AlignedArray
8c57702  Revert "unpermute use AlignedArray"
a6c4d38  hack: manual vectorize
1fddc58  more manual vectorize
09b5d01  Revert "more manual vectorize"
5f76d06  Revert "hack: manual vectorize"
d7631b1  hack: unpermute, maxnreg=32
1c6342c  Revert "hack: unpermute, maxnreg=32"
6ba3ab4  mv load ordering
6bcaa69  Revert "mv load ordering"
b1bb205  make orig_cols constexpr
6ca1b1d  hack rm trap
fec2f44  Revert "hack rm trap"
44d390e  Revert "make orig_cols constexpr"
14d26e8  cp: vectorize
4480029  cp: mv load ordering
80b704e  hack: rm bias handling (incorrect? why is it used?)
f954a1f  hack: read 8 - compute 8, instead of 8x(read 1 compute 1)
9603cf7  naive handle enable_input_buf
2e52a49  enable_input_buf use bitwise op
1c75b9b  hack: unpermute, maxnreg=64
4e40624  doActivationKernel reg=32
ea5a594  Revert "doActivationKernel reg=32"
a5d4189  hack: acti blocks 8->6
b52391a  Revert "hack: acti blocks 8->6"
7853d15  hack: acti - infinite num blocks
f670fa4  hack: acti - mid num blocks
50378ba  Revert "hack: acti - mid num blocks"
9d1456a  Revert "hack: acti - infinite num blocks"
ec89f0c  temp rm all
783120b  change test
7b0f471  Revert "temp rm all"
5873bb3  ARR_LENGTH_CONST
bb7a97a  Revert "ARR_LENGTH_CONST"
8c719e6  hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len
fece864  hack: unroll(4)
b9bb8c7  Revert "hack: unroll(4)"
7e6c876  Revert "hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len"
06003aa  hack NUM_EXPERTS_PER_NODE_CONST
afa5a61  temp: 4gpu bench
1583eb0  fix
3704820  temp rm all
3a74536  partial cp
348a536  cp change-block-thread, pragma-unroll, mv-if-check
0e62fa8  Revert "cp change-block-thread, pragma-unroll, mv-if-check"
941b68c  enable all except for 21:06
47f3d50  Merge remote-tracking branch 'upstream/main' into feat/speedup_moe
d504c61  Revert "feat: Benchmark mm_fp4 mxfp4 support and gemm autotune suppor…
822ae9b  Revert "Enabled alpha with the mx_fp4 format (#1688)"
299caf5  enable change-thread-block
d83a3cb  enable mv-unpermuted_row_to_permuted_row
68b8e6d  Revert "enable mv-unpermuted_row_to_permuted_row"
@@ -29,20 +29,24 @@
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

num_ranks = 4

test_configs = [
    # {
    #     "hidden_size": 7168,
    #     "num_experts": 256,
    #     "top_k": 8,
    #     "intermediate_size": 256,
    # },
    {
        "hidden_size": 7168,
        "num_experts": 256,
        "top_k": 8,
        "intermediate_size": 256,
    },
    {
        "hidden_size": 7168,
        "num_experts": 32,
        "num_experts": num_experts,
        "top_k": 8,
        "intermediate_size": 2048,
    },
    }
    for num_experts in [
        256 // num_ranks,
    ]
]

@@ -131,6 +135,13 @@ def bench_cutlass_fused_moe(
    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    if 1:
        print("HACK: mask some selected_experts")
        selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999

    tune_max_num_tokens = batch_size
    print(f"HACK: {tune_max_num_tokens=}")

    flash_output = torch.zeros_like(x)

    quant_scales = [

@@ -143,6 +154,7 @@ def bench_cutlass_fused_moe(
    ]
    hidden_states = x
    hidden_states, input_sf = fp4_quantize(x, a1_gs)
    print(f"{hidden_states.shape=}")

    # Warmup
    for _ in range(3):

@@ -156,7 +168,7 @@ def bench_cutlass_fused_moe(
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
            tune_max_num_tokens=16384,
            tune_max_num_tokens=tune_max_num_tokens,
        )

    if not skip_autotune:

@@ -171,10 +183,20 @@ def bench_cutlass_fused_moe(
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
            tune_max_num_tokens=16384,
            tune_max_num_tokens=tune_max_num_tokens,
        )
    ms_list = bench_gpu_time(
        lambda: fused_moe.cutlass_fused_moe(

    counter = 0

    def f():
        nonlocal counter
        counter += 1

        if counter == 10:
            print("hi call cudaProfilerStart")
            torch.cuda.cudart().cudaProfilerStart()

        fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,

@@ -184,14 +206,29 @@ def bench_cutlass_fused_moe(
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
        ),
    )
        )

        if counter == 10:
            print("hi call cudaProfilerStop")
            torch.cuda.cudart().cudaProfilerStop()

    ms_list = bench_gpu_time(f)
    median_ms = np.median(ms_list)
    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
    print(
        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}"
    )

    from flashinfer.testing.utils import bench_kineto
    for _ in range(5):
        ts = bench_kineto(
            f,
            ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"),
            suppress_kineto_output=False,
            num_tests=100,
        )
        print(f"Kineto output: ts_ms={['%.3f' % (t * 1000) for t in ts]}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

@@ -201,7 +238,7 @@ def bench_cutlass_fused_moe(
        help="Update the config file with the new profiling results",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=32, help="Number of tokens to profile"
        "--num-tokens", type=int, default=32768 * num_ranks, help="Number of tokens to profile"
    )
    parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
    args = parser.parse_args()
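With num_ranks = 4, the generated test config uses num_experts = 256 // num_ranks = 64 local experts, and the new --num-tokens default of 32768 * num_ranks works out to 131072 tokens. The benchmark diff above also replaces the plain bench_gpu_time lambda with a counter-gated closure f that wraps exactly one steady-state iteration (the 10th call) in cudaProfilerStart/cudaProfilerStop, presumably so that an external profiler such as Nsight Systems (e.g. nsys with --capture-range=cudaProfilerApi) records only that window, and then times three specific kernels via bench_kineto. Below is a minimal standalone sketch of the same profiler-window pattern; run_moe_once is a hypothetical stand-in for the real cutlass_fused_moe call and is not part of the PR.

    # Sketch only: capture one representative iteration in an external profiler.
    import torch

    def run_moe_once():
        # Hypothetical placeholder workload; the PR benchmarks
        # fused_moe.cutlass_fused_moe(...) here instead.
        a = torch.randn(8192, 7168, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(7168, 2048, device="cuda", dtype=torch.bfloat16)
        return a @ b

    counter = 0

    def timed_iteration():
        global counter
        counter += 1
        if counter == 10:  # open the capture window after warmup/autotune noise
            torch.cuda.cudart().cudaProfilerStart()
        run_moe_once()
        if counter == 10:  # close the window after one representative iteration
            torch.cuda.cudart().cudaProfilerStop()

    for _ in range(20):
        timed_iteration()
    torch.cuda.synchronize()

The same pattern applies unchanged when the body is the real fused MoE call; only the iteration index chosen for capture matters.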
@@ -39,7 +39,6 @@
        "out_dtype",
        "mma_sm",
        "use_128x4_sf_layout",
        "use_nvfp4",
    ],
    "moe": [
        "num_tokens",
Review comment on the new "if 1:" block:
This "if 1:" block appears to be temporary code for debugging and testing, as indicated by the "HACK" print statements. This block, including the hardcoded logic for masking experts and setting tune_max_num_tokens, should be removed before merging.
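If the masking and tune_max_num_tokens overrides are meant to stay available for benchmarking rather than being deleted outright, one way to address this comment is to gate them behind an explicit opt-in flag instead of "if 1:". The sketch below is only an illustration and not part of the PR: the --simulate-ep-masking flag and the standalone harness around it are hypothetical, while the masking line itself mirrors the one in the benchmark diff above.

    # Sketch (not the PR's code): make the expert-masking hack opt-in via a CLI flag.
    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-ep-masking",
        action="store_true",
        help="Randomly mask selected_experts to mimic a single expert-parallel rank",
    )
    args = parser.parse_args()

    num_ranks = 4
    top_k = 8
    num_tokens = 16
    # Dummy routing result standing in for compute_routing(...) in the benchmark.
    selected_experts = torch.randint(0, 256 // num_ranks, (num_tokens, top_k))

    if args.simulate_ep_masking:
        print("Masking some selected_experts to simulate expert parallelism")
        selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999

The tune_max_num_tokens override could be gated the same way.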