Commits (106 total; the diff below shows changes from 24 commits)
498027c  more (fzyzcjy, Sep 10, 2025)
8b4c44b  more (fzyzcjy, Sep 10, 2025)
db646ac  more (fzyzcjy, Sep 10, 2025)
9c5e3ca  temp 4gpu (fzyzcjy, Sep 10, 2025)
fd64748  more (fzyzcjy, Sep 10, 2025)
145efc4  more (fzyzcjy, Sep 10, 2025)
ed240ab  more (fzyzcjy, Sep 10, 2025)
9178d57  temp (fzyzcjy, Sep 10, 2025)
27d9053  Revert "temp" (fzyzcjy, Sep 10, 2025)
d39f3ec  tune_max_num_tokens 16k -> 32k (fzyzcjy, Sep 10, 2025)
2890f7e  more (fzyzcjy, Sep 10, 2025)
dde43f6  more (fzyzcjy, Sep 10, 2025)
671d1cd  more (fzyzcjy, Sep 10, 2025)
0158022  fix instasll err (fzyzcjy, Sep 16, 2025)
cc512e5  Merge branch 'main-upstream' into feat/bench_cutlass_moe (fzyzcjy, Sep 17, 2025)
0bd7a7f  Merge branch 'feat/hack_license' into feat/bench_cutlass_moe (fzyzcjy, Sep 17, 2025)
20da055  more (fzyzcjy, Sep 17, 2025)
99f5ff2  more (fzyzcjy, Sep 17, 2025)
622f8ae  more (fzyzcjy, Sep 17, 2025)
eb55ef1  hack: mask some selected experts (fzyzcjy, Sep 18, 2025)
723fea7  fix tune_max_num_tokens (fzyzcjy, Sep 18, 2025)
ec9cab1  hack findTotalEltsLessThanTarget (fzyzcjy, Sep 18, 2025)
29d4170  more (fzyzcjy, Sep 18, 2025)
20a361b  more (fzyzcjy, Sep 18, 2025)
d66ef25  more (fzyzcjy, Sep 18, 2025)
f31b592  writeSF (fzyzcjy, Sep 19, 2025)
480057d  pragma unroll (fzyzcjy, Sep 19, 2025)
742521e  Revert "writeSF" (fzyzcjy, Sep 19, 2025)
5cb9936  Revert "pragma unroll" (fzyzcjy, Sep 19, 2025)
5d7e4d3  redo pragma unroll (fzyzcjy, Sep 19, 2025)
d29f4f9  make hidden size const for everywhere (fzyzcjy, Sep 19, 2025)
8543c83  inter_size constexpr for activation (fzyzcjy, Sep 19, 2025)
efb5214  unroll permute copy (fzyzcjy, Sep 19, 2025)
920c1ac  change unroll (fzyzcjy, Sep 19, 2025)
b807967  fix wrong return (fzyzcjy, Sep 19, 2025)
06d5fc9  simp input_sf (fzyzcjy, Sep 19, 2025)
93f8690  try change order (fzyzcjy, Sep 19, 2025)
18676b4  hack (should revert): temp rm padding (fzyzcjy, Sep 19, 2025)
977b8ae  Revert "hack (should revert): temp rm padding" (fzyzcjy, Sep 19, 2025)
ac807d7  prefetch unpermuted_row (fzyzcjy, Sep 19, 2025)
ad285aa  temp hack: EXPAND_THREADS_PER_BLOCK 256->128 (fzyzcjy, Sep 19, 2025)
865e758  temp hack: EXPAND_THREADS_PER_BLOCK 256->32 (fzyzcjy, Sep 19, 2025)
ae5376f  Revert "temp hack: EXPAND_THREADS_PER_BLOCK 256->32" (fzyzcjy, Sep 19, 2025)
c3d3e9f  temp hack: EXPAND_THREADS_PER_BLOCK=128 + blocks=x2 (fzyzcjy, Sep 19, 2025)
0910742  bench kineto (fzyzcjy, Sep 19, 2025)
b15fa24  chore: rm log (fzyzcjy, Sep 19, 2025)
516bf70  chore: log (fzyzcjy, Sep 19, 2025)
630b482  chore: more tests (fzyzcjy, Sep 19, 2025)
54757d7  revert kernel to 09:10 (fzyzcjy, Sep 19, 2025)
ab88987  chore bench (fzyzcjy, Sep 19, 2025)
704c1c5  hack: only enable "thread/=2, block*=2" (fzyzcjy, Sep 19, 2025)
ed6178c  hack: 64thread (fzyzcjy, Sep 19, 2025)
bc48740  Revert "hack: 64thread" (fzyzcjy, Sep 19, 2025)
e5b01db  unroll topk in unpermute kernel (fzyzcjy, Sep 19, 2025)
4106c12  unpermute use AlignedArray (fzyzcjy, Sep 19, 2025)
8c57702  Revert "unpermute use AlignedArray" (fzyzcjy, Sep 19, 2025)
a6c4d38  hack: manual vectorize (fzyzcjy, Sep 19, 2025)
1fddc58  more manual vectorize (fzyzcjy, Sep 19, 2025)
09b5d01  Revert "more manual vectorize" (fzyzcjy, Sep 19, 2025)
5f76d06  Revert "hack: manual vectorize" (fzyzcjy, Sep 19, 2025)
d7631b1  hack: unpermute, maxnreg=32 (fzyzcjy, Sep 19, 2025)
1c6342c  Revert "hack: unpermute, maxnreg=32" (fzyzcjy, Sep 19, 2025)
6ba3ab4  mv load ordering (fzyzcjy, Sep 19, 2025)
6bcaa69  Revert "mv load ordering" (fzyzcjy, Sep 19, 2025)
b1bb205  make orig_cols constexpr (fzyzcjy, Sep 19, 2025)
6ca1b1d  hack rm trap (fzyzcjy, Sep 19, 2025)
fec2f44  Revert "hack rm trap" (fzyzcjy, Sep 19, 2025)
44d390e  Revert "make orig_cols constexpr" (fzyzcjy, Sep 19, 2025)
14d26e8  cp: vectorize (fzyzcjy, Sep 19, 2025)
4480029  cp: mv load ordering (fzyzcjy, Sep 19, 2025)
80b704e  hack: rm bias handling (incorrect? why is it used?) (fzyzcjy, Sep 19, 2025)
f954a1f  hack: read 8 - compute 8, instead of 8x(read 1 compute 1) (fzyzcjy, Sep 19, 2025)
9603cf7  naive handle enable_input_buf (fzyzcjy, Sep 19, 2025)
2e52a49  enable_input_buf use bitwise op (fzyzcjy, Sep 19, 2025)
1c75b9b  hack: unpermute, maxnreg=64 (fzyzcjy, Sep 19, 2025)
4e40624  doActivationKernel reg=32 (fzyzcjy, Sep 19, 2025)
ea5a594  Revert "doActivationKernel reg=32" (fzyzcjy, Sep 19, 2025)
a5d4189  hack: acti blocks 8->6 (fzyzcjy, Sep 19, 2025)
b52391a  Revert "hack: acti blocks 8->6" (fzyzcjy, Sep 19, 2025)
7853d15  hack: acti - infinite num blocks (fzyzcjy, Sep 19, 2025)
f670fa4  hack: acti - mid num blocks (fzyzcjy, Sep 19, 2025)
50378ba  Revert "hack: acti - mid num blocks" (fzyzcjy, Sep 19, 2025)
9d1456a  Revert "hack: acti - infinite num blocks" (fzyzcjy, Sep 19, 2025)
ec89f0c  temp rm all (fzyzcjy, Sep 19, 2025)
783120b  change test (fzyzcjy, Sep 19, 2025)
7b0f471  Revert "temp rm all" (fzyzcjy, Sep 19, 2025)
5873bb3  ARR_LENGTH_CONST (fzyzcjy, Sep 19, 2025)
bb7a97a  Revert "ARR_LENGTH_CONST" (fzyzcjy, Sep 19, 2025)
8c719e6  hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len (fzyzcjy, Sep 19, 2025)
fece864  hack: unroll(4) (fzyzcjy, Sep 19, 2025)
b9bb8c7  Revert "hack: unroll(4)" (fzyzcjy, Sep 19, 2025)
7e6c876  Revert "hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len" (fzyzcjy, Sep 19, 2025)
06003aa  hack NUM_EXPERTS_PER_NODE_CONST (fzyzcjy, Sep 19, 2025)
afa5a61  temp: 4gpu bench (fzyzcjy, Sep 19, 2025)
1583eb0  fix (fzyzcjy, Sep 19, 2025)
3704820  temp rm all (fzyzcjy, Sep 19, 2025)
3a74536  partial cp (fzyzcjy, Sep 19, 2025)
348a536  cp change-block-thread, pragma-unroll, mv-if-check (fzyzcjy, Sep 19, 2025)
0e62fa8  Revert "cp change-block-thread, pragma-unroll, mv-if-check" (fzyzcjy, Sep 19, 2025)
941b68c  enable all except for 21:06 (fzyzcjy, Sep 19, 2025)
47f3d50  Merge remote-tracking branch 'upstream/main' into feat/speedup_moe (fzyzcjy, Sep 19, 2025)
d504c61  Revert "feat: Benchmark mm_fp4 mxfp4 support and gemm autotune suppor… (fzyzcjy, Sep 20, 2025)
822ae9b  Revert "Enabled alpha with the mx_fp4 format (#1688)" (fzyzcjy, Sep 20, 2025)
299caf5  enable change-thread-block (fzyzcjy, Sep 20, 2025)
d83a3cb  enable mv-unpermuted_row_to_permuted_row (fzyzcjy, Sep 20, 2025)
68b8e6d  Revert "enable mv-unpermuted_row_to_permuted_row" (fzyzcjy, Sep 20, 2025)
57 changes: 42 additions & 15 deletions benchmarks/bench_cutlass_fused_moe.py
@@ -29,20 +29,24 @@
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

num_ranks = 2

test_configs = [
# {
# "hidden_size": 7168,
# "num_experts": 256,
# "top_k": 8,
# "intermediate_size": 256,
# },
{
"hidden_size": 7168,
"num_experts": 256,
"top_k": 8,
"intermediate_size": 256,
},
{
"hidden_size": 7168,
"num_experts": 32,
"num_experts": num_experts,
"top_k": 8,
"intermediate_size": 2048,
},
}
for num_experts in [
256 // num_ranks,
]
]


@@ -131,6 +135,13 @@ def bench_cutlass_fused_moe(
router_logits = torch.randn(m, e, dtype=otype).cuda()
routing_weights, selected_experts = compute_routing(router_logits, top_k)

if 1:
print("HACK: mask some selected_experts")
selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999

tune_max_num_tokens = batch_size
print(f"HACK: {tune_max_num_tokens=}")

Comment on lines +138 to +144 (Contributor, severity: high):
This `if 1:` block appears to be temporary code for debugging and testing, as indicated by the "HACK" print statements. This block, including the hardcoded logic for masking experts and setting tune_max_num_tokens, should be removed before merging.

flash_output = torch.zeros_like(x)

quant_scales = [
@@ -143,6 +154,7 @@
]
hidden_states = x
hidden_states, input_sf = fp4_quantize(x, a1_gs)
print(f"{hidden_states.shape=}")
Contributor comment (severity: medium):
This print statement appears to be for debugging purposes and should be removed before merging.


# Warmup
for _ in range(3):
@@ -156,7 +168,7 @@
quant_scales=quant_scales,
input_sf=input_sf,
output=flash_output,
tune_max_num_tokens=16384,
tune_max_num_tokens=tune_max_num_tokens,
)

if not skip_autotune:
@@ -171,10 +183,20 @@ def bench_cutlass_fused_moe(
quant_scales=quant_scales,
input_sf=input_sf,
output=flash_output,
tune_max_num_tokens=16384,
tune_max_num_tokens=tune_max_num_tokens,
)
ms_list = bench_gpu_time(
lambda: fused_moe.cutlass_fused_moe(

counter = 0

def f():
nonlocal counter
counter += 1

if counter == 10:
print("hi call cudaProfilerStart")
torch.cuda.cudart().cudaProfilerStart()

fused_moe.cutlass_fused_moe(
hidden_states,
selected_experts.to(torch.int),
routing_weights,
Expand All @@ -184,8 +206,13 @@ def bench_cutlass_fused_moe(
quant_scales=quant_scales,
input_sf=input_sf,
output=flash_output,
),
)
)

if counter == 10:
print("hi call cudaProfilerStop")
torch.cuda.cudart().cudaProfilerStop()

ms_list = bench_gpu_time(f)
median_ms = np.median(ms_list)
print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
print(
@@ -201,7 +228,7 @@
help="Update the config file with the new profiling results",
)
parser.add_argument(
"--num-tokens", type=int, default=32, help="Number of tokens to profile"
"--num-tokens", type=int, default=32768 * num_ranks, help="Number of tokens to profile"
)
parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
args = parser.parse_args()
45 changes: 44 additions & 1 deletion csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
@@ -865,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken(
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
int64_t const arr_length, T const target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
@@ -881,6 +881,49 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
return target_location + 1;
}

template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
constexpr int ARR_LENGTH_CONST = 128;
if (arr_length != ARR_LENGTH_CONST) {
asm("trap;");
}
Comment on lines 886 to 889 (Contributor, severity: critical):
The function findTotalEltsLessThanTarget_v2 has a hardcoded ARR_LENGTH_CONST and uses asm("trap;") if the input array length does not match. This makes the function non-generic and unsafe for general use, as it will cause a crash for any other input size. This experimental implementation should be made more robust or removed if it's not ready for production.


constexpr unsigned full_mask = 0xffffffffu;
constexpr int WARP_SZ = 32;
const int lane_id = threadIdx.x & (WARP_SZ - 1);

int local_count = 0;
#pragma unroll
for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
const int idx = lane_id + k * WARP_SZ;
T v = sorted_indices[idx];
local_count += (v < target) ? 1 : 0;
}

#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
local_count += __shfl_down_sync(full_mask, local_count, offset);
}
int total = __shfl_sync(full_mask, local_count, 0);

return (int64_t)total;
}

template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);

// return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);

// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
// if (out_v1 != out_v2) {
// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
// asm("trap;");
// }
// return out_v1;
}
Comment on lines 913 to 925 (Contributor, severity: high):
This function contains a significant amount of commented-out debug code, including printf statements and an asm("trap;"). This should be cleaned up and removed before merging.
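A possible follow-up to the comment on lines 886 to 889: the warp-parallel counting idea in findTotalEltsLessThanTarget_v2 does not inherently need the hardcoded ARR_LENGTH_CONST or the asm("trap;") guard. Below is a minimal sketch of a length-agnostic variant, offered as an illustration only: it is not code from this PR, the _warp name and the XOR-shuffle reduction are my own choices, and it keeps _v2's assumption that the call is made by a full, converged warp.

// Sketch only (assumption, not part of this PR): the same warp-level counting
// as findTotalEltsLessThanTarget_v2, but striding over an arbitrary arr_length
// so no hardcoded length or trap is required.
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget_warp(T const* sorted_indices,
                                                           int64_t const arr_length,
                                                           T const target) {
  constexpr unsigned full_mask = 0xffffffffu;
  constexpr int WARP_SZ = 32;
  int const lane_id = threadIdx.x & (WARP_SZ - 1);

  // Each lane counts every 32nd element, so any arr_length is handled,
  // including lengths that are not a multiple of the warp size.
  int local_count = 0;
  for (int64_t idx = lane_id; idx < arr_length; idx += WARP_SZ) {
    local_count += (sorted_indices[idx] < target) ? 1 : 0;
  }

  // Butterfly (XOR) reduction: every lane ends up holding the full sum,
  // so no separate broadcast shuffle is needed.
#pragma unroll
  for (int offset = WARP_SZ / 2; offset > 0; offset >>= 1) {
    local_count += __shfl_xor_sync(full_mask, local_count, offset);
  }
  return static_cast<int64_t>(local_count);
}

Whether this actually beats the _v1 binary search presumably depends on the number of experts and the calling context, which is likely why both variants sit behind the findTotalEltsLessThanTarget dispatcher in this PR.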


template <class T>
using sizeof_bits = cutlass::sizeof_bits<
typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
5 changes: 5 additions & 0 deletions flashinfer/autotuner.py
@@ -447,6 +447,11 @@ def choose_one(
logger.debug(
f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}"
)
else:
# NOTE ADD
logger.debug(
f"[AutoTunner]: HACK ADD cache hit {custom_op=} {input_shapes=}"
)
return runner, tactic
Comment on lines +450 to 455 (Contributor, severity: medium):
This else block with the "HACK ADD" debug log statement appears to be temporary and should be removed.


assert len(runners) > 0, "At least one runner is required"
6 changes: 6 additions & 0 deletions flashinfer/fused_moe/core.py
@@ -870,6 +870,12 @@ def cutlass_fused_moe(
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)

print(
"hi flashinfer cutlass_fused_moe "
f"{input.shape=} {input.dtype=} "
f"{token_selected_experts.shape=}"
)
Contributor comment (severity: medium):
This print statement appears to be for debugging and should be removed before merging.


num_rows = input.shape[0]
if min_latency_mode:
num_rows *= fc2_expert_weights.shape[0]
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -17,11 +17,11 @@ name = "flashinfer-python"
description = "FlashInfer: Kernel Library for LLM Serving"
requires-python = ">=3.9,<4.0"
authors = [{ name = "FlashInfer team" }]
license = "Apache-2.0"
#license = "Apache-2.0"
Contributor comment (severity: critical):
Commenting out the license field is a critical issue. A project's license must be clearly declared in its metadata. Please restore this line.

Suggested change:
-#license = "Apache-2.0"
+license = "Apache-2.0"

readme = "README.md"
urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" }
dynamic = ["dependencies", "version"]
license-files = ["LICENSE", "licenses/*"]
#license-files = ["LICENSE", "licenses/*"]
Contributor comment (severity: critical):
Commenting out license-files is a critical issue. Please restore this line to ensure license files are included in the package.

Suggested change:
-#license-files = ["LICENSE", "licenses/*"]
+license-files = ["LICENSE", "licenses/*"]


[build-system]
requires = ["setuptools>=77", "packaging>=24"]