From 498027c5957f65a06d41c4bb6a4fda639835b932 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 09:11:50 +0800 Subject: [PATCH 001/103] more --- benchmarks/bench_cutlass_fused_moe.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index e0dff8e215..62b8d66f31 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -31,15 +31,15 @@ test_configs = [ + # { + # "hidden_size": 7168, + # "num_experts": 256, + # "top_k": 8, + # "intermediate_size": 256, + # }, { "hidden_size": 7168, - "num_experts": 256, - "top_k": 8, - "intermediate_size": 256, - }, - { - "hidden_size": 7168, - "num_experts": 32, + "num_experts": 288, # 1gpu "top_k": 8, "intermediate_size": 2048, }, @@ -201,7 +201,7 @@ def bench_cutlass_fused_moe( help="Update the config file with the new profiling results", ) parser.add_argument( - "--num-tokens", type=int, default=32, help="Number of tokens to profile" + "--num-tokens", type=int, default=16384, help="Number of tokens to profile" ) parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() From 8b4c44beab2b76f5e5fe68f92578ed72a16e9fb6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 09:13:41 +0800 Subject: [PATCH 002/103] more --- benchmarks/bench_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 62b8d66f31..b5151ed275 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -201,7 +201,7 @@ def bench_cutlass_fused_moe( help="Update the config file with the new profiling results", ) parser.add_argument( - "--num-tokens", type=int, default=16384, help="Number of tokens to profile" + "--num-tokens", type=int, default=32768, help="Number of tokens to profile" ) parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() From db646ac8891e2c9ab5db4b3108818199962bc269 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 09:27:12 +0800 Subject: [PATCH 003/103] more --- benchmarks/bench_cutlass_fused_moe.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index b5151ed275..c49c7affdf 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -173,8 +173,17 @@ def bench_cutlass_fused_moe( output=flash_output, tune_max_num_tokens=16384, ) - ms_list = bench_gpu_time( - lambda: fused_moe.cutlass_fused_moe( + + counter = 0 + + def f(): + nonlocal counter + counter += 1 + + if counter == 20: + torch.cuda.cudart().cudaProfilerStart() + + fused_moe.cutlass_fused_moe( hidden_states, selected_experts.to(torch.int), routing_weights, @@ -184,8 +193,12 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - ), - ) + ) + + if counter == 20: + torch.cuda.cudart().cudaProfilerStop() + + ms_list = bench_gpu_time(f) median_ms = np.median(ms_list) print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}") print( From 9c5e3cac710928f24773e81d2ce152e4c25ab801 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 09:40:46 +0800 Subject: [PATCH 004/103] temp 4gpu --- benchmarks/bench_cutlass_fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index c49c7affdf..365903a355 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -39,7 +39,8 @@ # }, { "hidden_size": 7168, - "num_experts": 288, # 1gpu + # "num_experts": 288, # 1gpu + "num_experts": 288 // 4, # 4gpu "top_k": 8, "intermediate_size": 2048, }, From fd647484611c391331fa4031a4e99526d7b6da7b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 10:02:41 +0800 Subject: [PATCH 005/103] more --- benchmarks/bench_cutlass_fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 365903a355..6aaeb29ae9 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -39,8 +39,8 @@ # }, { "hidden_size": 7168, - # "num_experts": 288, # 1gpu - "num_experts": 288 // 4, # 4gpu + # "num_experts": 256, # 1gpu + "num_experts": 256 // 4, # 4gpu "top_k": 8, "intermediate_size": 2048, }, From 145efc4e294c69825e6555c33fe5eff7e26fa00a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 10:27:53 +0800 Subject: [PATCH 006/103] more --- benchmarks/bench_cutlass_fused_moe.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 6aaeb29ae9..f251efc626 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -39,11 +39,15 @@ # }, { "hidden_size": 7168, - # "num_experts": 256, # 1gpu - "num_experts": 256 // 4, # 4gpu + "num_experts": num_experts, "top_k": 8, "intermediate_size": 2048, - }, + } + for num_experts in [ + 256, + 256 // 2, + 256 // 4, + ] ] From ed240ab147049a31fd05314d30bbae5a01f1ba7f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 10:31:30 +0800 Subject: [PATCH 007/103] more --- benchmarks/bench_cutlass_fused_moe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index f251efc626..ce1d778b7a 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -47,6 +47,9 @@ 256, 256 // 2, 256 // 4, + 256 // 8, + 256 // 16, + 256 // 32, ] ] From 9178d57fe9a278d00a9da0334226367bc8238540 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 10:33:38 +0800 Subject: [PATCH 008/103] temp --- benchmarks/bench_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index ce1d778b7a..e58fac8a1a 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -222,7 +222,7 @@ def f(): help="Update the config file with the new profiling results", ) parser.add_argument( - "--num-tokens", type=int, default=32768, help="Number of tokens to profile" + "--num-tokens", type=int, default=16384, help="Number of tokens to profile" ) parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() From 27d90532caec3aca0709c837cf8c7b8454a9d90f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 10:34:08 +0800 Subject: [PATCH 009/103] Revert "temp" This reverts commit 9178d57fe9a278d00a9da0334226367bc8238540. 
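Note on the profiling hook introduced in PATCH 003/103 above: the benchmarked call is wrapped in a closure that fires torch.cuda.cudart().cudaProfilerStart()/cudaProfilerStop() on one chosen iteration, so an external capture (for example Nsight Systems with --capture-range=cudaProfilerApi) records just that iteration instead of the warmup and autotune phases. Below is a minimal, self-contained sketch of the same pattern; make_profiled_fn, profile_iter and run_moe_once are illustrative names, not code taken from the repository.

import torch

def make_profiled_fn(fn, profile_iter=10):
    # Wrap `fn` so that exactly one call (the profile_iter-th) is bracketed
    # by cudaProfilerStart/cudaProfilerStop.
    counter = 0

    def wrapped(*args, **kwargs):
        nonlocal counter
        counter += 1
        if counter == profile_iter:
            torch.cuda.cudart().cudaProfilerStart()
        out = fn(*args, **kwargs)
        if counter == profile_iter:
            torch.cuda.cudart().cudaProfilerStop()
        return out

    return wrapped

# Usage sketch (run_moe_once is a placeholder for the benchmarked call):
# ms_list = bench_gpu_time(make_profiled_fn(run_moe_once))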
--- benchmarks/bench_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index e58fac8a1a..ce1d778b7a 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -222,7 +222,7 @@ def f(): help="Update the config file with the new profiling results", ) parser.add_argument( - "--num-tokens", type=int, default=16384, help="Number of tokens to profile" + "--num-tokens", type=int, default=32768, help="Number of tokens to profile" ) parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() From d39f3ecf4f8b3561e2112b904b59ca17809ba703 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 13:50:37 +0800 Subject: [PATCH 010/103] tune_max_num_tokens 16k -> 32k --- benchmarks/bench_cutlass_fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index ce1d778b7a..451cab3231 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -164,7 +164,7 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - tune_max_num_tokens=16384, + tune_max_num_tokens=32768, ) if not skip_autotune: @@ -179,7 +179,7 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - tune_max_num_tokens=16384, + tune_max_num_tokens=32768, ) counter = 0 From 2890f7e8f8c9f1321a306905499dbcba15eab2e1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 14:34:11 +0800 Subject: [PATCH 011/103] more --- flashinfer/autotuner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 32bf52d113..416cddd848 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -447,6 +447,11 @@ def choose_one( logger.debug( f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}" ) + else: + # NOTE ADD + logger.debug( + f"[AutoTunner]: HACK ADD cache hit {custom_op=} {input_shapes=}" + ) return runner, tactic assert len(runners) > 0, "At least one runner is required" From dde43f6a56fd173bed5dede8ac14fe22b62287ca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 14:35:38 +0800 Subject: [PATCH 012/103] more --- flashinfer/fused_moe/core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index ad6169c515..3063ca5691 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1988,6 +1988,13 @@ def trtllm_fp4_block_scale_routed_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. 
""" + + print( + "hi trtllm_fp4_block_scale_routed_moe " + f"{hidden_states.shape=} {hidden_states.dtype=} " + f"{topk_ids.shape=}" + ) + return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( None, topk_ids, From 671d1cda54dd50606d5be30eeb0f8f9b032a5060 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 10 Sep 2025 14:40:17 +0800 Subject: [PATCH 013/103] more --- flashinfer/fused_moe/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 3063ca5691..15d9b5f477 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -870,6 +870,12 @@ def cutlass_fused_moe( if enable_pdl is None: enable_pdl = device_support_pdl(input.device) + print( + "hi flashinfer cutlass_fused_moe " + f"{input.shape=} {input.dtype=} " + f"{token_selected_experts.shape=}" + ) + num_rows = input.shape[0] if min_latency_mode: num_rows *= fc2_expert_weights.shape[0] @@ -1988,13 +1994,6 @@ def trtllm_fp4_block_scale_routed_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ - - print( - "hi trtllm_fp4_block_scale_routed_moe " - f"{hidden_states.shape=} {hidden_states.dtype=} " - f"{topk_ids.shape=}" - ) - return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( None, topk_ids, From 01580228db9f5ed8463ceae06c5b508792d337de Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 16 Sep 2025 17:44:02 +0800 Subject: [PATCH 014/103] fix instasll err --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9c6f1ffc22..120a631929 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,11 @@ name = "flashinfer-python" description = "FlashInfer: Kernel Library for LLM Serving" requires-python = ">=3.9,<4.0" authors = [{ name = "FlashInfer team" }] -license = "Apache-2.0" +#license = "Apache-2.0" readme = "README.md" urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" } dynamic = ["dependencies", "version"] -license-files = ["LICENSE", "licenses/*"] +#license-files = ["LICENSE", "licenses/*"] [build-system] requires = ["setuptools>=77", "packaging>=24"] From 20da055f273815a6b9cec8e0caf170947e512a65 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 17 Sep 2025 14:18:13 +0800 Subject: [PATCH 015/103] more --- benchmarks/bench_cutlass_fused_moe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 451cab3231..cc8781efcb 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -44,12 +44,12 @@ "intermediate_size": 2048, } for num_experts in [ - 256, + # 256, 256 // 2, - 256 // 4, - 256 // 8, - 256 // 16, - 256 // 32, + # 256 // 4, + # 256 // 8, + # 256 // 16, + # 256 // 32, ] ] From 99f5ff28eb11f2e7437fb5ee884c00a91bdae067 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 17 Sep 2025 14:20:09 +0800 Subject: [PATCH 016/103] more --- benchmarks/bench_cutlass_fused_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index cc8781efcb..86d3ac9051 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -189,6 +189,7 @@ def f(): counter += 1 if counter == 20: + print("hi 
call cudaProfilerStart") torch.cuda.cudart().cudaProfilerStart() fused_moe.cutlass_fused_moe( @@ -204,6 +205,7 @@ def f(): ) if counter == 20: + print("hi call cudaProfilerStop") torch.cuda.cudart().cudaProfilerStop() ms_list = bench_gpu_time(f) From 622f8aeb8c34cbd282ef8a3cbcb34e3b3b6123be Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Wed, 17 Sep 2025 14:20:51 +0800 Subject: [PATCH 017/103] more --- benchmarks/bench_cutlass_fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 86d3ac9051..0ed6a8de90 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -188,7 +188,7 @@ def f(): nonlocal counter counter += 1 - if counter == 20: + if counter == 10: print("hi call cudaProfilerStart") torch.cuda.cudart().cudaProfilerStart() @@ -204,7 +204,7 @@ def f(): output=flash_output, ) - if counter == 20: + if counter == 10: print("hi call cudaProfilerStop") torch.cuda.cudart().cudaProfilerStop() From eb55ef1bdfcb3688108a55068dbe2445d9ca6f4c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 18 Sep 2025 22:13:29 +0800 Subject: [PATCH 018/103] hack: mask some selected experts --- benchmarks/bench_cutlass_fused_moe.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 0ed6a8de90..6961ec7230 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -29,6 +29,7 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +num_ranks = 2 test_configs = [ # { @@ -44,12 +45,7 @@ "intermediate_size": 2048, } for num_experts in [ - # 256, - 256 // 2, - # 256 // 4, - # 256 // 8, - # 256 // 16, - # 256 // 32, + 256 // num_ranks, ] ] @@ -139,6 +135,10 @@ def bench_cutlass_fused_moe( router_logits = torch.randn(m, e, dtype=otype).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) + if 1: + print("HACK: mask some selected_experts") + selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999 + flash_output = torch.zeros_like(x) quant_scales = [ @@ -151,6 +151,7 @@ def bench_cutlass_fused_moe( ] hidden_states = x hidden_states, input_sf = fp4_quantize(x, a1_gs) + print(f"{hidden_states.shape=}") # Warmup for _ in range(3): @@ -224,7 +225,7 @@ def f(): help="Update the config file with the new profiling results", ) parser.add_argument( - "--num-tokens", type=int, default=32768, help="Number of tokens to profile" + "--num-tokens", type=int, default=32768 * num_ranks, help="Number of tokens to profile" ) parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() From 723fea77282b4271c016ff2b430e9966df17f46c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 18 Sep 2025 22:15:51 +0800 Subject: [PATCH 019/103] fix tune_max_num_tokens --- benchmarks/bench_cutlass_fused_moe.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 6961ec7230..39c53cd3cd 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -139,6 +139,9 @@ def bench_cutlass_fused_moe( print("HACK: mask some selected_experts") selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999 + tune_max_num_tokens = batch_size + print(f"HACK: {tune_max_num_tokens=}") + flash_output 
= torch.zeros_like(x) quant_scales = [ @@ -165,7 +168,7 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - tune_max_num_tokens=32768, + tune_max_num_tokens=tune_max_num_tokens, ) if not skip_autotune: @@ -180,7 +183,7 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - tune_max_num_tokens=32768, + tune_max_num_tokens=tune_max_num_tokens, ) counter = 0 From ec9cab1ff13177b67c38cf487a91beb859fd8eeb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 18 Sep 2025 23:01:14 +0800 Subject: [PATCH 020/103] hack findTotalEltsLessThanTarget --- .../cutlass_fused_moe_kernels.cuh | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d94bc69b23..71a3fd53ba 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -865,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken( // ============================== Infer GEMM sizes ================================= // TODO Could linear search be better for small # experts template -__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, +__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices, int64_t const arr_length, T const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) { @@ -881,6 +881,47 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, return target_location + 1; } +template +__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { + constexpr int ARR_LENGTH_CONST = 128; + if (arr_length != ARR_LENGTH_CONST) { + asm("trap;"); + } + + constexpr unsigned full_mask = 0xffffffffu; + constexpr int WARP_SZ = 32; + const int lane_id = threadIdx.x & (WARP_SZ - 1); + + int local_count = 0; +#pragma unroll + for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { + const int idx = lane_id + k * WARP_SZ; + T v = sorted_indices[idx]; + local_count += (v < target) ? 1 : 0; + } + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(full_mask, local_count, offset); + } + int total = __shfl_sync(full_mask, local_count, 0); + + return (int64_t)total; +} + +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + +// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); +// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); +// if (out_v1 != out_v2) { +// printf("different output! 
v1=%lld v2=%lld\n", out_v1, out_v2); +// asm("trap"); +// } +// return out_v1; +} + template using sizeof_bits = cutlass::sizeof_bits< typename cutlass_kernels::TllmToCutlassTypeAdapter>::type>; From 29d41704ded8e1c5ac2f4e266e1db75c4ece3b49 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 18 Sep 2025 23:05:51 +0800 Subject: [PATCH 021/103] more --- .../cutlass_fused_moe_kernels.cuh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 71a3fd53ba..8972e593e7 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -911,15 +911,15 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); - -// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); -// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); -// if (out_v1 != out_v2) { -// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); -// asm("trap"); -// } -// return out_v1; +// return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + + int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); + int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + if (out_v1 != out_v2) { + printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); + asm("trap;"); + } + return out_v1; } template From 20a361be86bf226e881d505cad03994937a18f62 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 18 Sep 2025 23:08:06 +0800 Subject: [PATCH 022/103] more --- .../cutlass_fused_moe_kernels.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 8972e593e7..2e5f792d0a 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -911,15 +911,17 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { + return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); + // return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); - int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); - int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); - if (out_v1 != out_v2) { - printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); - asm("trap;"); - } - return out_v1; +// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); +// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); +// if (out_v1 != out_v2) { +// printf("different output! 
v1=%lld v2=%lld\n", out_v1, out_v2); +// asm("trap;"); +// } +// return out_v1; } template From d66ef257ed5fae804683abf94bccf030d9a80264 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 07:49:18 +0800 Subject: [PATCH 023/103] more --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 2e5f792d0a..7ba843b06b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -911,9 +911,9 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { - return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); +// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); -// return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); From f31b59238fd593eb015d655891133af46d33eb32 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 08:43:44 +0800 Subject: [PATCH 024/103] writeSF --- .../cutlass_fused_moe_kernels.cuh | 71 ++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7ba843b06b..5549053ad3 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1082,6 +1082,71 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, } } +template +__device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t num_tokens_before_expert, int64_t expert_id, + int64_t source_token_id, int64_t token_id, int64_t elem_idx, + int64_t num_cols, + TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { + static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; + + // We need to offset into the scaling factors for just this expert + auto act_sf_expert = + act_sf_flat + getOffsetActivationSF( + expert_id, num_tokens_before_expert, num_cols, + (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) + ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + + // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this + // expert + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); + if (sf_out) { + if (input_sf) { + auto const sf_in = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED_128x4); + return *sf_in; + } else { + return 0x00; + } + } +} + +template +__device__ void writeSF_v2_write(int64_t num_tokens_before_expert, int64_t expert_id, + int64_t source_token_id, int64_t token_id, int64_t elem_idx, + int64_t num_cols, + TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF value_to_write) { + static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; + + // We need to offset into the scaling factors for just this expert + auto act_sf_expert = + act_sf_flat + getOffsetActivationSF( + expert_id, num_tokens_before_expert, num_cols, + (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + + // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this + // expert + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); + if (sf_out) { + *sf_out = value_to_write; + } +} + // ====================== Compute FP8 dequant scale only =============================== __global__ void computeFP8DequantScaleKernel(float const** alpha_scale_ptr_array, int64_t const num_experts_per_node, @@ -1577,9 +1642,13 @@ __global__ void expandInputRowsKernel( } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - writeSF(num_tokens_before_expert, expert, source_row, + TmaWarpSpecializedGroupedGemmInput::ElementSF sf_value = + writeSF_v2_read(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); + writeSF_v2_write(num_tokens_before_expert, expert, source_row, + permuted_row, elem_index, padded_hidden_size, + fc1_act_sf_flat, sf_value); dest_row_ptr[elem_index] = in_vec; } } From 480057d7254487731e82b8b86763ccca86488fba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 08:47:12 +0800 Subject: [PATCH 025/103] pragma unroll --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 5549053ad3..83277d6b02 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1611,7 +1611,7 @@ __global__ void expandInputRowsKernel( permuted_row * hidden_size / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; - int64_t const stride = EXPAND_THREADS_PER_BLOCK; + 
constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); @@ -1627,7 +1627,11 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { + constexpr int NUM_ELEMS_IN_COL_CONST = 7168 / 8; + assert(num_elems_in_col == NUM_ELEMS_IN_COL_CONST); + +#pragma unroll + for (int elem_index = start_offset; elem_index < NUM_ELEMS_IN_COL_CONST; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; if constexpr (need_nvfp4_quant || need_mxfp8_quant) { auto res = quantizePackedFPXValue Date: Fri, 19 Sep 2025 08:47:20 +0800 Subject: [PATCH 026/103] Revert "writeSF" This reverts commit f31b59238fd593eb015d655891133af46d33eb32. --- .../cutlass_fused_moe_kernels.cuh | 71 +------------------ 1 file changed, 1 insertion(+), 70 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 83277d6b02..402b6ff67f 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1082,71 +1082,6 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, } } -template -__device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t num_tokens_before_expert, int64_t expert_id, - int64_t source_token_id, int64_t token_id, int64_t elem_idx, - int64_t num_cols, - TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { - static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; - - // We need to offset into the scaling factors for just this expert - auto act_sf_expert = - act_sf_flat + getOffsetActivationSF( - expert_id, num_tokens_before_expert, num_cols, - (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) - ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); - - // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this - // expert - auto sf_out = - cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); - if (sf_out) { - if (input_sf) { - auto const sf_in = cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED_128x4); - return *sf_in; - } else { - return 0x00; - } - } -} - -template -__device__ void writeSF_v2_write(int64_t num_tokens_before_expert, int64_t expert_id, - int64_t source_token_id, int64_t token_id, int64_t elem_idx, - int64_t num_cols, - TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF value_to_write) { - static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; - - // We need to offset into the scaling factors for just this expert - auto act_sf_expert = - act_sf_flat + getOffsetActivationSF( - expert_id, num_tokens_before_expert, num_cols, - (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) - ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); - - // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this - // expert - auto sf_out = - cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); - if (sf_out) { - *sf_out = value_to_write; - } -} - // ====================== Compute FP8 dequant scale only =============================== __global__ void computeFP8DequantScaleKernel(float const** alpha_scale_ptr_array, int64_t const num_experts_per_node, @@ -1646,13 +1581,9 @@ __global__ void expandInputRowsKernel( } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - TmaWarpSpecializedGroupedGemmInput::ElementSF sf_value = - writeSF_v2_read(num_tokens_before_expert, expert, source_row, + writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); - writeSF_v2_write(num_tokens_before_expert, expert, source_row, - permuted_row, elem_index, padded_hidden_size, - fc1_act_sf_flat, sf_value); dest_row_ptr[elem_index] = in_vec; } } From 5cb9936e54f1c3228a0bef6f5041d4876d2b4580 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 08:50:40 +0800 Subject: [PATCH 027/103] Revert "pragma unroll" This reverts commit 480057d7254487731e82b8b86763ccca86488fba. 
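Note on the findTotalEltsLessThanTarget changes in PATCH 020/103 through PATCH 023/103 above: the v2 path spreads the fixed 128-entry array across the 32 lanes of a warp (4 entries per lane), counts per lane how many entries fall below the target, and reduces the per-lane counts with __shfl_down_sync, while the original v1 path binary-searches. The trap-on-mismatch cross-check only makes sense because both paths must return the same quantity: the number of sorted entries strictly less than the target, i.e. the insertion point of the target. A small Python reference of that contract (illustrative only, not repository code):

import bisect

def find_total_elts_less_than_target(sorted_indices, target):
    # Number of sorted entries strictly below `target`; this is the value the
    # binary-search path (v1) and the warp-reduction path (v2) must agree on.
    return bisect.bisect_left(sorted_indices, target)

assert find_total_elts_less_than_target([0, 4, 4, 9], 4) == 1
assert find_total_elts_less_than_target([0, 4, 4, 9], 5) == 3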
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 402b6ff67f..7ba843b06b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1546,7 +1546,7 @@ __global__ void expandInputRowsKernel( permuted_row * hidden_size / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; - constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; + int64_t const stride = EXPAND_THREADS_PER_BLOCK; int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); @@ -1562,11 +1562,7 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; - constexpr int NUM_ELEMS_IN_COL_CONST = 7168 / 8; - assert(num_elems_in_col == NUM_ELEMS_IN_COL_CONST); - -#pragma unroll - for (int elem_index = start_offset; elem_index < NUM_ELEMS_IN_COL_CONST; elem_index += stride) { + for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; if constexpr (need_nvfp4_quant || need_mxfp8_quant) { auto res = quantizePackedFPXValue Date: Fri, 19 Sep 2025 08:54:36 +0800 Subject: [PATCH 028/103] redo pragma unroll --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7ba843b06b..6e0acb5a98 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1545,9 +1545,12 @@ __global__ void expandInputRowsKernel( auto* dest_row_ptr = reinterpret_cast(permuted_output) + permuted_row * hidden_size / ELEM_PER_THREAD; + constexpr int HIDDEN_SIZE_CONST = 7168; + if (hidden_size != HIDDEN_SIZE_CONST) { asm("trap;"); } + int64_t const start_offset = threadIdx.x; - int64_t const stride = EXPAND_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; + constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; + constexpr int64_t num_elems_in_col = HIDDEN_SIZE_CONST / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); @@ -1562,6 +1565,7 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? 
fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; +#pragma unroll for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; if constexpr (need_nvfp4_quant || need_mxfp8_quant) { From d29f4f927abbc59a57246ab3ec5fa1bb424826ac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:04:29 +0800 Subject: [PATCH 029/103] make hidden size const for everywhere --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 6e0acb5a98..30c37817db 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1465,12 +1465,15 @@ template (permuted_output) + permuted_row * hidden_size / ELEM_PER_THREAD; - constexpr int HIDDEN_SIZE_CONST = 7168; - if (hidden_size != HIDDEN_SIZE_CONST) { asm("trap;"); } - int64_t const start_offset = threadIdx.x; constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; - constexpr int64_t num_elems_in_col = HIDDEN_SIZE_CONST / ELEM_PER_THREAD; + constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); From 8543c834b5c2eaf7f4941919b2bb21d74a858ef3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:08:46 +0800 Subject: [PATCH 030/103] inter_size constexpr for activation --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 30c37817db..9711f7bf49 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2177,10 +2177,13 @@ template && @@ -2265,9 +2268,9 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto output_vec = reinterpret_cast(safe_inc_ptr(output, output_offset)); auto bias_ptr_vec = reinterpret_cast(bias_ptr + bias_offset); int64_t const start_offset = tid; - int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; + constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK; assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0); - int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; + constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; @@ -2275,6 +2278,8 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fn.alpha = gate_alpha; fn.beta = gate_beta; fn.limit = gate_limit; + +#pragma unroll for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert(gemm_result_vec[elem_index + gated_off_vec]); From efb52146d346eae4d40a0c2056c5db467d3bb657 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:20:53 +0800 Subject: [PATCH 031/103] unroll permute copy --- .../cutlass_fused_moe_kernels.cuh | 127 +++++++++++++++--- 1 file changed, 110 insertions(+), 17 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh 
b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9711f7bf49..8c54832511 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1082,6 +1082,71 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, } } +template +__device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t num_tokens_before_expert, int64_t expert_id, + int64_t source_token_id, int64_t token_id, int64_t elem_idx, + int64_t num_cols, + TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { + static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; + + // We need to offset into the scaling factors for just this expert + auto act_sf_expert = + act_sf_flat + getOffsetActivationSF( + expert_id, num_tokens_before_expert, num_cols, + (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + + // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this + // expert + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); + if (sf_out) { + if (input_sf) { + auto const sf_in = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED_128x4); + return *sf_in; + } else { + return 0x00; + } + } +} + +template +__device__ void writeSF_v2_write(int64_t num_tokens_before_expert, int64_t expert_id, + int64_t source_token_id, int64_t token_id, int64_t elem_idx, + int64_t num_cols, + TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF value_to_write) { + static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; + + // We need to offset into the scaling factors for just this expert + auto act_sf_expert = + act_sf_flat + getOffsetActivationSF( + expert_id, num_tokens_before_expert, num_cols, + (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + + // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this + // expert + auto sf_out = + cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, + std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, + QuantizationSFLayout::SWIZZLED_128x4); + if (sf_out) { + *sf_out = value_to_write; + } +} + // ====================== Compute FP8 dequant scale only =============================== __global__ void computeFP8DequantScaleKernel(float const** alpha_scale_ptr_array, int64_t const num_experts_per_node, @@ -1565,29 +1630,57 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? 
fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; -#pragma unroll - for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { - auto in_vec = source_row_ptr[elem_index]; - if constexpr (need_nvfp4_quant || need_mxfp8_quant) { - auto res = quantizePackedFPXValue( - in_vec, global_scale_val, num_tokens_before_expert, expert, permuted_row, elem_index, - padded_hidden_size, fc1_act_sf_flat, - is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); - static_assert(sizeof(res) == sizeof(*dest_row_ptr), - "Quantized value must be the same size as the output"); - dest_row_ptr[elem_index] = res; - } else { + if constexpr (need_nvfp4_quant || need_mxfp8_quant) { + asm("trap;"); + } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - writeSF(num_tokens_before_expert, expert, source_row, + + constexpr int BUF_SIZE = 4; + static_assert(ceilDiv(num_elems_in_col, stride) == BUF_SIZE); + DataElem data_buf[BUF_SIZE]; + TmaWarpSpecializedGroupedGemmInput::ElementSF sf_buf[BUF_SIZE]; + +#pragma unroll + for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { + data_buf[idx] = source_row_ptr[elem_index]; + sf_buf[idx] = writeSF_v2_read(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); - dest_row_ptr[elem_index] = in_vec; - } + } + +#pragma unroll + for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { + dest_row_ptr[elem_index] = data_buf[idx]; + writeSF_v2_write(num_tokens_before_expert, expert, source_row, + permuted_row, elem_index, padded_hidden_size, + fc1_act_sf_flat, sf_buf[idx]); + } } +// #pragma unroll +// for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { +// auto in_vec = source_row_ptr[elem_index]; +// if constexpr (need_nvfp4_quant || need_mxfp8_quant) { +// auto res = quantizePackedFPXValue( +// in_vec, global_scale_val, num_tokens_before_expert, expert, permuted_row, elem_index, +// padded_hidden_size, fc1_act_sf_flat, +// is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 +// : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); +// static_assert(sizeof(res) == sizeof(*dest_row_ptr), +// "Quantized value must be the same size as the output"); +// dest_row_ptr[elem_index] = res; +// } else { +// assert(act_scale_idx == 0 && +// "Cannot use per-expert act scale for pre-quantized activations"); +// writeSF(num_tokens_before_expert, expert, source_row, +// permuted_row, elem_index, padded_hidden_size, +// fc1_act_sf_flat, input_sf); +// dest_row_ptr[elem_index] = in_vec; +// } +// } + // Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan // values in the padded SF atom Use VecSize per thread since we are just writing out zeros so // every thread can process a whole vector From 920c1acf286d54971736536964f41e6bc8685a66 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:27:48 +0800 Subject: [PATCH 032/103] change unroll --- .../fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 8c54832511..2ebe7c016d 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1644,6 +1644,9 @@ __global__ void expandInputRowsKernel( #pragma unroll for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { data_buf[idx] = source_row_ptr[elem_index]; + } +#pragma unroll + for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { sf_buf[idx] = writeSF_v2_read(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); @@ -1652,6 +1655,9 @@ __global__ void expandInputRowsKernel( #pragma unroll for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { dest_row_ptr[elem_index] = data_buf[idx]; + } +#pragma unroll + for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { writeSF_v2_write(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, sf_buf[idx]); From b80796795052d95c44be69e4335bd5aaac5682bc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:29:22 +0800 Subject: [PATCH 033/103] fix wrong return --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 2ebe7c016d..5b889385e4 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1117,6 +1117,8 @@ __device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t return 0x00; } } + + return 0x00; } template From 06d5fc99fc155e60f15fe8b56e4f85251eb1c8c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:33:47 +0800 Subject: [PATCH 034/103] simp input_sf --- .../cutlass_fused_moe_kernels.cuh | 38 ++++--------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 
5b889385e4..36be4c9481 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1089,36 +1089,14 @@ __device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; - - // We need to offset into the scaling factors for just this expert - auto act_sf_expert = - act_sf_flat + getOffsetActivationSF( - expert_id, num_tokens_before_expert, num_cols, - (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) - ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); - - // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this - // expert - auto sf_out = - cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); - if (sf_out) { - if (input_sf) { - auto const sf_in = cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED_128x4); - return *sf_in; - } else { - return 0x00; - } - } - - return 0x00; + assert(input_sf != nullptr); + // TODO correct? + auto const sf_in = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED_128x4); + return *sf_in; } template From 93f869047073cfd99ffd80d29d7c2fc017038f47 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:42:48 +0800 Subject: [PATCH 035/103] try change order --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 36be4c9481..1c0c67748c 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1624,23 +1624,17 @@ __global__ void expandInputRowsKernel( #pragma unroll for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { data_buf[idx] = source_row_ptr[elem_index]; - } -#pragma unroll - for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { sf_buf[idx] = writeSF_v2_read(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); } -#pragma unroll - for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { - dest_row_ptr[elem_index] = data_buf[idx]; - } #pragma unroll for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { writeSF_v2_write(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, sf_buf[idx]); + dest_row_ptr[elem_index] = data_buf[idx]; } } From 18676b4978a43227a34a622512f3c25beeee6ea7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:55:56 +0800 Subject: [PATCH 036/103] 
hack (should revert): temp rm padding --- .../cutlass_fused_moe_kernels.cuh | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 1c0c67748c..3fa1ae5865 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1702,46 +1702,46 @@ __global__ void expandInputRowsKernel( asm volatile("griddepcontrol.launch_dependents;"); #endif - // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values - // in the padded SF atom - if constexpr (is_nvfp4 || is_mxfp8) { - int64_t const start_offset = threadIdx.x; - int64_t const stride = EXPAND_THREADS_PER_BLOCK; - // Use VecSize per thread since we are just writing out zeros so every thread can process a - // whole vector - int64_t const padded_num_elems_in_col = padded_hidden_size / VecSize; - assert(padded_hidden_size % VecSize == 0); - - constexpr int min_num_tokens_alignment = - is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 - : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; - static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, - "Min num tokens alignment must be a power of two"); - // Since we don't know a priori how much padding is needed we assume the max per expert - // NOTE: we don't use (min_num_tokens_alignment-1) to be able to do power of two divisions - int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; - - for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; - padding_token += gridDim.x) { - int64_t expert = padding_token / min_num_tokens_alignment; - int64_t num_tokens_before_expert = expert_first_token_offset[expert]; - int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; - int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; - int64_t padding_to_expert = TmaWarpSpecializedGroupedGemmInput::alignToSfDim( - tokens_to_expert, min_num_tokens_alignment) - - tokens_to_expert; - int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; - if (expert_pad_idx < padding_to_expert) { - for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; - elem_index += stride) { - writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, - num_tokens_after_expert + expert_pad_idx, elem_index, - padded_hidden_size, fc1_act_sf_flat, - /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 - } - } - } - } +// // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values +// // in the padded SF atom +// if constexpr (is_nvfp4 || is_mxfp8) { +// int64_t const start_offset = threadIdx.x; +// int64_t const stride = EXPAND_THREADS_PER_BLOCK; +// // Use VecSize per thread since we are just writing out zeros so every thread can process a +// // whole vector +// int64_t const padded_num_elems_in_col = padded_hidden_size / VecSize; +// assert(padded_hidden_size % VecSize == 0); +// +// constexpr int min_num_tokens_alignment = +// is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 +// : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; +// static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, +// "Min num tokens alignment must be a power of two"); +// // Since we don't know a priori how much padding is needed we assume the max per expert +// // NOTE: we don't use (min_num_tokens_alignment-1) to be able to do power of two divisions +// int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; +// +// for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; +// padding_token += gridDim.x) { +// int64_t expert = padding_token / min_num_tokens_alignment; +// int64_t num_tokens_before_expert = expert_first_token_offset[expert]; +// int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; +// int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; +// int64_t padding_to_expert = TmaWarpSpecializedGroupedGemmInput::alignToSfDim( +// tokens_to_expert, min_num_tokens_alignment) - +// tokens_to_expert; +// int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; +// if (expert_pad_idx < padding_to_expert) { +// for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; +// elem_index += stride) { +// writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, +// num_tokens_after_expert + expert_pad_idx, elem_index, +// padded_hidden_size, fc1_act_sf_flat, +// /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 +// } +// } +// } +// } } template From 977b8ae5f865c0c87027ff805d6a70d921e62336 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 09:57:23 +0800 Subject: [PATCH 037/103] Revert "hack (should revert): temp rm padding" This reverts commit 18676b4978a43227a34a622512f3c25beeee6ea7. --- .../cutlass_fused_moe_kernels.cuh | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 3fa1ae5865..1c0c67748c 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1702,46 +1702,46 @@ __global__ void expandInputRowsKernel( asm volatile("griddepcontrol.launch_dependents;"); #endif -// // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values -// // in the padded SF atom -// if constexpr (is_nvfp4 || is_mxfp8) { -// int64_t const start_offset = threadIdx.x; -// int64_t const stride = EXPAND_THREADS_PER_BLOCK; -// // Use VecSize per thread since we are just writing out zeros so every thread can process a -// // whole vector -// int64_t const padded_num_elems_in_col = padded_hidden_size / VecSize; -// assert(padded_hidden_size % VecSize == 0); -// -// constexpr int min_num_tokens_alignment = -// is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 -// : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; -// static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, -// "Min num tokens alignment must be a power of two"); -// // Since we don't know a priori how much padding is needed we assume the max per expert -// // NOTE: we don't use (min_num_tokens_alignment-1) to be able to do power of two divisions -// int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; -// -// for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; -// padding_token += gridDim.x) { -// int64_t expert = padding_token / min_num_tokens_alignment; -// int64_t num_tokens_before_expert = expert_first_token_offset[expert]; -// int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; -// int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; -// int64_t padding_to_expert = TmaWarpSpecializedGroupedGemmInput::alignToSfDim( -// tokens_to_expert, min_num_tokens_alignment) - -// tokens_to_expert; -// int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; -// if (expert_pad_idx < padding_to_expert) { -// for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; -// elem_index += stride) { -// writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, -// num_tokens_after_expert + expert_pad_idx, elem_index, -// padded_hidden_size, fc1_act_sf_flat, -// /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 -// } -// } -// } -// } + // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values + // in the padded SF atom + if constexpr (is_nvfp4 || is_mxfp8) { + int64_t const start_offset = threadIdx.x; + int64_t const stride = EXPAND_THREADS_PER_BLOCK; + // Use VecSize per thread since we are just writing out zeros so every thread can process a + // whole vector + int64_t const padded_num_elems_in_col = padded_hidden_size / VecSize; + assert(padded_hidden_size % VecSize == 0); + + constexpr int min_num_tokens_alignment = + is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; + static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, + "Min num tokens alignment must be a power of two"); + // Since we don't know a priori how much padding is needed we assume the max per expert + // NOTE: we don't use (min_num_tokens_alignment-1) to be able to do power of two divisions + int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; + + for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; + padding_token += gridDim.x) { + int64_t expert = padding_token / min_num_tokens_alignment; + int64_t num_tokens_before_expert = expert_first_token_offset[expert]; + int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; + int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; + int64_t padding_to_expert = TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + tokens_to_expert, min_num_tokens_alignment) - + tokens_to_expert; + int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; + if (expert_pad_idx < padding_to_expert) { + for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; + elem_index += stride) { + writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, + num_tokens_after_expert + expert_pad_idx, elem_index, + padded_hidden_size, fc1_act_sf_flat, + /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 + } + } + } + } } template From ac807d76e5570233938810b0097fa1aa11f8607c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:01:06 +0800 Subject: [PATCH 038/103] prefetch unpermuted_row --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 1c0c67748c..7663f2ad10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1566,10 +1566,18 @@ __global__ void expandInputRowsKernel( int64_t const padded_hidden_size = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(hidden_size, min_k_dim_alignment); + int64_t unpermuted_row_next = permuted_row_to_unpermuted_row[blockIdx.x]; + int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; for (int64_t permuted_row = blockIdx.x; permuted_row < num_valid_tokens; permuted_row += gridDim.x) { - int64_t const unpermuted_row = permuted_row_to_unpermuted_row[permuted_row]; + int64_t const unpermuted_row = unpermuted_row_next; + { + int64_t idx = permuted_row + gridDim.x; + if (idx < num_valid_tokens) { + unpermuted_row_next = permuted_row_to_unpermuted_row[idx]; + } + } // Load 128-bits per thread From ad285aabf86f1cc855b536f4aa16137e01cc33f1 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:13:28 +0800 Subject: [PATCH 039/103] temp hack: EXPAND_THREADS_PER_BLOCK 256->128 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7663f2ad10..ebc1d084e6 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1502,7 +1502,8 @@ __host__ __device__ 
constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +// constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template Date: Fri, 19 Sep 2025 10:18:51 +0800 Subject: [PATCH 040/103] temp hack: EXPAND_THREADS_PER_BLOCK 256->32 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index ebc1d084e6..7c1c466efd 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1503,7 +1503,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // source matrix, we simply take the modulus of the expanded index. // constexpr static int EXPAND_THREADS_PER_BLOCK = 256; -constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 32; template Date: Fri, 19 Sep 2025 10:23:45 +0800 Subject: [PATCH 041/103] Revert "temp hack: EXPAND_THREADS_PER_BLOCK 256->32" This reverts commit 865e7582e557a9e3ca0d97627064e213a201088d. --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7c1c466efd..ebc1d084e6 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1503,7 +1503,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // source matrix, we simply take the modulus of the expanded index. // constexpr static int EXPAND_THREADS_PER_BLOCK = 256; -constexpr static int EXPAND_THREADS_PER_BLOCK = 32; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template Date: Fri, 19 Sep 2025 10:25:03 +0800 Subject: [PATCH 042/103] temp hack: EXPAND_THREADS_PER_BLOCK=128 + blocks=x2 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index ebc1d084e6..e2f21f997d 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1779,7 +1779,7 @@ void expandInputRowsKernelLauncher( static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
- int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens)); int64_t const threads = EXPAND_THREADS_PER_BLOCK; auto func = [&]() { From 09107427ed16765f9f26618542fd909c250e971a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:37:08 +0800 Subject: [PATCH 043/103] bench kineto --- benchmarks/bench_cutlass_fused_moe.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 39c53cd3cd..5f019a4c64 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -219,6 +219,14 @@ def f(): f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}" ) + from flashinfer.testing.utils import bench_kineto + ts = bench_kineto( + f, + ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"), + suppress_kineto_output=False, + ) + print(f"Kineto output: {ts=}") + if __name__ == "__main__": parser = argparse.ArgumentParser() From b15fa2474cba60d71ba793a5b19d38236867fc0a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:37:53 +0800 Subject: [PATCH 044/103] chore: rm log --- flashinfer/fused_moe/core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 15d9b5f477..04b40a636f 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -870,11 +870,11 @@ def cutlass_fused_moe( if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - print( - "hi flashinfer cutlass_fused_moe " - f"{input.shape=} {input.dtype=} " - f"{token_selected_experts.shape=}" - ) + # print( + # "hi flashinfer cutlass_fused_moe " + # f"{input.shape=} {input.dtype=} " + # f"{token_selected_experts.shape=}" + # ) num_rows = input.shape[0] if min_latency_mode: From 516bf709294b1cd4e59a8461902b45728e206112 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:39:23 +0800 Subject: [PATCH 045/103] chore: log --- benchmarks/bench_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 5f019a4c64..be98d6498d 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -225,7 +225,7 @@ def f(): ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"), suppress_kineto_output=False, ) - print(f"Kineto output: {ts=}") + print(f"Kineto output: ts_ms={['%.3f' % (t * 1000) for t in ts]}") if __name__ == "__main__": From 630b482ed0d83deaca3ebd46ccc4519eaf2fb1eb Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:41:12 +0800 Subject: [PATCH 046/103] chore: more tests --- benchmarks/bench_cutlass_fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index be98d6498d..b3814f3aa9 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -224,6 +224,7 @@ def f(): f, ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"), suppress_kineto_output=False, + num_tests=100, ) print(f"Kineto output: ts_ms={['%.3f' % (t * 1000) for t in ts]}") From 54757d752fb134bfd0d14d0af2dc6244e8b5e8c4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:44:27 +0800 
Subject: [PATCH 047/103] revert kernel to 09:10 --- .../cutlass_fused_moe_kernels.cuh | 122 +++--------------- 1 file changed, 20 insertions(+), 102 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index e2f21f997d..9711f7bf49 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1082,51 +1082,6 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, } } -template -__device__ TmaWarpSpecializedGroupedGemmInput::ElementSF writeSF_v2_read(int64_t num_tokens_before_expert, int64_t expert_id, - int64_t source_token_id, int64_t token_id, int64_t elem_idx, - int64_t num_cols, - TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { - static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; - assert(input_sf != nullptr); - // TODO correct? - auto const sf_in = cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols / VecSize, const_cast(input_sf), - QuantizationSFLayout::SWIZZLED_128x4); - return *sf_in; -} - -template -__device__ void writeSF_v2_write(int64_t num_tokens_before_expert, int64_t expert_id, - int64_t source_token_id, int64_t token_id, int64_t elem_idx, - int64_t num_cols, - TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF value_to_write) { - static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; - - // We need to offset into the scaling factors for just this expert - auto act_sf_expert = - act_sf_flat + getOffsetActivationSF( - expert_id, num_tokens_before_expert, num_cols, - (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) - ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); - - // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this - // expert - auto sf_out = - cvt_quant_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, - std::nullopt /* numRows */, num_cols / VecSize, act_sf_expert, - QuantizationSFLayout::SWIZZLED_128x4); - if (sf_out) { - *sf_out = value_to_write; - } -} - // ====================== Compute FP8 dequant scale only =============================== __global__ void computeFP8DequantScaleKernel(float const** alpha_scale_ptr_array, int64_t const num_experts_per_node, @@ -1502,8 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -// constexpr static int EXPAND_THREADS_PER_BLOCK = 256; -constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template ( + in_vec, global_scale_val, num_tokens_before_expert, expert, permuted_row, elem_index, + padded_hidden_size, fc1_act_sf_flat, + is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + static_assert(sizeof(res) == sizeof(*dest_row_ptr), + "Quantized value must be the same size as the output"); + dest_row_ptr[elem_index] = res; + } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - - constexpr int BUF_SIZE = 7; - static_assert(ceilDiv(num_elems_in_col, stride) == BUF_SIZE); - DataElem data_buf[BUF_SIZE]; - TmaWarpSpecializedGroupedGemmInput::ElementSF sf_buf[BUF_SIZE]; - -#pragma unroll - for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { - data_buf[idx] = source_row_ptr[elem_index]; - sf_buf[idx] = writeSF_v2_read(num_tokens_before_expert, expert, source_row, + writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); - } - -#pragma unroll - for (int elem_index = start_offset, idx = 0; elem_index < num_elems_in_col; elem_index += stride, ++idx) { - writeSF_v2_write(num_tokens_before_expert, expert, source_row, - permuted_row, elem_index, padded_hidden_size, - fc1_act_sf_flat, sf_buf[idx]); - dest_row_ptr[elem_index] = data_buf[idx]; - } + dest_row_ptr[elem_index] = in_vec; + } } -// #pragma unroll -// for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { -// auto in_vec = source_row_ptr[elem_index]; -// if constexpr (need_nvfp4_quant || need_mxfp8_quant) { -// auto res = quantizePackedFPXValue( -// in_vec, global_scale_val, num_tokens_before_expert, expert, permuted_row, elem_index, -// padded_hidden_size, fc1_act_sf_flat, -// is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 -// : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); -// static_assert(sizeof(res) == sizeof(*dest_row_ptr), -// "Quantized value must be the same size as the output"); -// dest_row_ptr[elem_index] = res; -// } else { -// assert(act_scale_idx == 0 && -// "Cannot use per-expert act scale for pre-quantized activations"); -// writeSF(num_tokens_before_expert, expert, source_row, -// permuted_row, elem_index, padded_hidden_size, -// fc1_act_sf_flat, input_sf); -// dest_row_ptr[elem_index] = in_vec; -// } -// } - // Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan // values in the padded SF atom Use VecSize per thread since we are just writing out zeros so // every thread can process a whole vector @@ -1779,7 +1697,7 @@ void expandInputRowsKernelLauncher( static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
- int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens)); int64_t const threads = EXPAND_THREADS_PER_BLOCK; auto func = [&]() { From ab8898781ce37b8988100505193e62339893ed6b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:47:01 +0800 Subject: [PATCH 048/103] chore bench --- benchmarks/bench_cutlass_fused_moe.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index b3814f3aa9..850ec118ba 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -220,13 +220,14 @@ def f(): ) from flashinfer.testing.utils import bench_kineto - ts = bench_kineto( - f, - ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"), - suppress_kineto_output=False, - num_tests=100, - ) - print(f"Kineto output: ts_ms={['%.3f' % (t * 1000) for t in ts]}") + for _ in range(5): + ts = bench_kineto( + f, + ("expandInputRowsKernel", "doActivationKernel", "finalizeMoeRoutingKernel"), + suppress_kineto_output=False, + num_tests=100, + ) + print(f"Kineto output: ts_ms={['%.3f' % (t * 1000) for t in ts]}") if __name__ == "__main__": From 704c1c5b9bf3e202fca6ea3e62a2ff983d8ef48c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 10:51:06 +0800 Subject: [PATCH 049/103] hack: only enable "thread/=2, block*=2" --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9711f7bf49..44d028868a 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template Date: Fri, 19 Sep 2025 10:54:20 +0800 Subject: [PATCH 050/103] hack: 64thread --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 44d028868a..6fc374696d 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 64; template Date: Fri, 19 Sep 2025 10:56:09 +0800 Subject: [PATCH 051/103] Revert "hack: 64thread" This reverts commit ed6178cce390087f17641d4a6ab98bf724d82a98. 
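Context for the sweep: 64 threads per block appears to be too few for this copy-bound kernel, so this goes back to 128. The sweep itself was done by editing EXPAND_THREADS_PER_BLOCK and rebuilding; a standalone way to compare block sizes for a bandwidth-bound kernel looks roughly like the sketch below (illustrative only: the grid-stride copy kernel, buffer sizes, and the 8-blocks-per-SM heuristic are placeholders, not the flashinfer launcher).

    #include <cstdio>
    #include <cuda_runtime.h>

    // Grid-stride copy: a stand-in for a bandwidth-bound expand/permute kernel.
    __global__ void copy_kernel(const int4* __restrict__ in, int4* __restrict__ out, size_t n) {
      size_t i = blockIdx.x * size_t(blockDim.x) + threadIdx.x;
      size_t stride = size_t(gridDim.x) * blockDim.x;
      for (; i < n; i += stride) out[i] = in[i];  // 16 bytes per thread per iteration
    }

    int main() {
      size_t n = size_t(1) << 24;  // 16M int4 elements = 256 MiB per buffer
      int4 *in = nullptr, *out = nullptr;
      cudaMalloc(&in, n * sizeof(int4));
      cudaMalloc(&out, n * sizeof(int4));
      int sm_count = 0;
      cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);
      for (int threads : {64, 128, 256}) {
        int blocks = 8 * sm_count;  // same "8 blocks per SM" heuristic as the launcher
        cudaEvent_t beg, end;
        cudaEventCreate(&beg);
        cudaEventCreate(&end);
        copy_kernel<<<blocks, threads>>>(in, out, n);  // warm-up
        cudaEventRecord(beg);
        copy_kernel<<<blocks, threads>>>(in, out, n);
        cudaEventRecord(end);
        cudaEventSynchronize(end);
        float ms = 0.f;
        cudaEventElapsedTime(&ms, beg, end);
        printf("threads=%d  %.3f ms\n", threads, ms);
        cudaEventDestroy(beg);
        cudaEventDestroy(end);
      }
      cudaFree(in);
      cudaFree(out);
      return 0;
    }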
--- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 6fc374696d..44d028868a 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 64; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template Date: Fri, 19 Sep 2025 10:57:27 +0800 Subject: [PATCH 052/103] unroll topk in unpermute kernel --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 44d028868a..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1798,8 +1798,11 @@ template Date: Fri, 19 Sep 2025 11:25:08 +0800 Subject: [PATCH 053/103] unpermute use AlignedArray --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..1e5da7a875 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1818,10 +1818,10 @@ __global__ void finalizeMoeRoutingKernel( int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; - using BiasElem = cutlass::Array; - using InputElem = cutlass::Array; - using OutputElem = cutlass::Array; - using ComputeElem = cutlass::Array; + using BiasElem = cutlass::AlignedArray; + using InputElem = cutlass::AlignedArray; + using OutputElem = cutlass::AlignedArray; + using ComputeElem = cutlass::AlignedArray; auto const* bias_v = reinterpret_cast(bias); auto const* expanded_permuted_rows_v = reinterpret_cast(expanded_permuted_rows); auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); From 8c577021771844decee34a47d5c3eee0892d9697 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 11:40:38 +0800 Subject: [PATCH 054/103] Revert "unpermute use AlignedArray" This reverts commit 4106c124f904d167ae9f5467a196ecf46b65d69d. 
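Background for the reverted change: cutlass::Array<T, N> is aligned only to alignof(T), so the compiler cannot always prove that a 16-byte access is legal, while cutlass::AlignedArray carries an explicit alignas and is therefore eligible for a single 128-bit load/store. The commits that follow go back to casting through int4 by hand instead. Minimal illustration of the difference (standalone sketch; the header layout and the alignment template argument may differ between CUTLASS versions):

    #include <cutlass/array.h>
    #include <cutlass/numeric_types.h>

    using Packed  = cutlass::Array<cutlass::half_t, 8>;             // aligned only to alignof(half_t)
    using Aligned = cutlass::AlignedArray<cutlass::half_t, 8, 16>;  // same layout, but alignas(16)

    static_assert(sizeof(Packed) == 16 && sizeof(Aligned) == 16, "same 16-byte payload");
    static_assert(alignof(Aligned) == 16, "vector-width alignment is explicit");

    __global__ void copy_aligned(Aligned const* __restrict__ in, Aligned* __restrict__ out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = in[i];  // eligible for a single 128-bit load/store
    }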
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 1e5da7a875..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1818,10 +1818,10 @@ __global__ void finalizeMoeRoutingKernel( int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; - using BiasElem = cutlass::AlignedArray; - using InputElem = cutlass::AlignedArray; - using OutputElem = cutlass::AlignedArray; - using ComputeElem = cutlass::AlignedArray; + using BiasElem = cutlass::Array; + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = cutlass::Array; auto const* bias_v = reinterpret_cast(bias); auto const* expanded_permuted_rows_v = reinterpret_cast(expanded_permuted_rows); auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); From a6c4d38d54743cf025811f7c1d0739498655644a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 11:44:33 +0800 Subject: [PATCH 055/103] hack: manual vectorize --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..9d8e204d22 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1856,8 +1856,11 @@ __global__ void finalizeMoeRoutingKernel( auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - ComputeElem expert_result = - arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); +// ComputeElem expert_result = +// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); + InputElem input_val = reinterpret_cast(*reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index)); + ComputeElem expert_result = arrayConvert(input_val); if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); @@ -1866,8 +1869,11 @@ __global__ void finalizeMoeRoutingKernel( thread_output = thread_output + row_scale * expert_result; } - OutputElem output_elem = arrayConvert(thread_output); - reduced_row_ptr_v[elem_index] = output_elem; +// OutputElem output_elem = arrayConvert(thread_output); +// reduced_row_ptr_v[elem_index] = output_elem; + int4 output_elem = reinterpret_cast(arrayConvert(thread_output)); + static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); From 1fddc58ae61541b5c2bed2af0670a61e6dfed583 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 11:46:49 +0800 Subject: [PATCH 056/103] more manual vectorize --- .../cutlass_fused_moe_kernels.cuh | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9d8e204d22..8267abbe36 100644 
--- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -1800,6 +1801,10 @@ __global__ void finalizeMoeRoutingKernel( ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { +if constexpr (not (std::is_same_v and std::is_same_v)) { + printf("finalizeMoeRoutingKernel see unsupported dtype\n"); + asm("trap;"); +} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1859,8 +1864,8 @@ __global__ void finalizeMoeRoutingKernel( // ComputeElem expert_result = // arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - InputElem input_val = reinterpret_cast(*reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index)); - ComputeElem expert_result = arrayConvert(input_val); + const int4 input_val = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); @@ -1871,14 +1876,17 @@ __global__ void finalizeMoeRoutingKernel( // OutputElem output_elem = arrayConvert(thread_output); // reduced_row_ptr_v[elem_index] = output_elem; - int4 output_elem = reinterpret_cast(arrayConvert(thread_output)); + // TODO alignment issue? + __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); + int4 output_elem = *reinterpret_cast(&output_elem_original); static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); - *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } +} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From 09b5d01fb2a1cebee090419b4f933c86286b8422 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:07:52 +0800 Subject: [PATCH 057/103] Revert "more manual vectorize" This reverts commit 1fddc58ae61541b5c2bed2af0670a61e6dfed583. 
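The "// TODO alignment issue?" question in the reverted hunk is the real problem with this variant: OutputElem is a cutlass::Array with no alignment specifier, so reinterpreting a stack copy of it as int4 is only well-defined once the local is forced to 16-byte alignment, or once the bytes are moved with memcpy (which the compiler usually turns into the same 128-bit store anyway). The two safe shapes of that store, shown with a simplified stand-in type:

    #include <cstring>
    #include <cuda_fp16.h>

    struct OutVec { __half data[8]; };  // stand-in for cutlass::Array<half, 8>: 16 bytes, 2-byte aligned

    __device__ void store_v4_aligned(OutVec v, int4* dst) {
      __align__(16) OutVec tmp = v;                 // force vector-width alignment on the local
      *dst = *reinterpret_cast<int4 const*>(&tmp);  // single 16-byte store
    }

    __device__ void store_v4_memcpy(OutVec const& v, int4* dst) {
      memcpy(dst, &v, sizeof(int4));                // nvcc typically lowers this to one vector store
    }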
--- .../cutlass_fused_moe_kernels.cuh | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 8267abbe36..9d8e204d22 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -1801,10 +1800,6 @@ __global__ void finalizeMoeRoutingKernel( ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { -if constexpr (not (std::is_same_v and std::is_same_v)) { - printf("finalizeMoeRoutingKernel see unsupported dtype\n"); - asm("trap;"); -} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1864,8 +1859,8 @@ if constexpr (not (std::is_same_v and std::is_sam // ComputeElem expert_result = // arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - const int4 input_val = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); - ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); + InputElem input_val = reinterpret_cast(*reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index)); + ComputeElem expert_result = arrayConvert(input_val); if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); @@ -1876,17 +1871,14 @@ if constexpr (not (std::is_same_v and std::is_sam // OutputElem output_elem = arrayConvert(thread_output); // reduced_row_ptr_v[elem_index] = output_elem; - // TODO alignment issue? - __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); - int4 output_elem = *reinterpret_cast(&output_elem_original); + int4 output_elem = reinterpret_cast(arrayConvert(thread_output)); static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); - *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From 5f76d06845af7865cf53de8bc2db6da130b9dd92 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:07:52 +0800 Subject: [PATCH 058/103] Revert "hack: manual vectorize" This reverts commit a6c4d38d54743cf025811f7c1d0739498655644a. 
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9d8e204d22..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1856,11 +1856,8 @@ __global__ void finalizeMoeRoutingKernel( auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; -// ComputeElem expert_result = -// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - InputElem input_val = reinterpret_cast(*reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index)); - ComputeElem expert_result = arrayConvert(input_val); + ComputeElem expert_result = + arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); @@ -1869,11 +1866,8 @@ __global__ void finalizeMoeRoutingKernel( thread_output = thread_output + row_scale * expert_result; } -// OutputElem output_elem = arrayConvert(thread_output); -// reduced_row_ptr_v[elem_index] = output_elem; - int4 output_elem = reinterpret_cast(arrayConvert(thread_output)); - static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); - *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; + OutputElem output_elem = arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); From d7631b143e80c9b9a42f6c07e9fab928359ce05e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:10:06 +0800 Subject: [PATCH 059/103] hack: unpermute, maxnreg=32 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..8e93b0f91c 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1795,7 +1795,9 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. template -__global__ void finalizeMoeRoutingKernel( +__global__ +__maxnreg__(32) +void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, From 1c6342cd15b3709965ccc7bc8e54a655636bed3f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:11:58 +0800 Subject: [PATCH 060/103] Revert "hack: unpermute, maxnreg=32" This reverts commit d7631b143e80c9b9a42f6c07e9fab928359ce05e. 
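A 32-register cap is too tight here: it tends to push the per-k accumulators into local memory, which costs more than the extra occupancy it buys, so this goes back to the unconstrained build. A looser __maxnreg__(64) variant is tried again further down. For reference, the two standard ways to bound per-thread register usage for a single kernel (__maxnreg__ needs a recent CUDA toolkit; the qualifier order follows the diff above, and the numbers are placeholders):

    __global__ __maxnreg__(64) void capped_kernel(float* out) {
      // the compiler must fit this kernel into 64 registers per thread
      out[blockIdx.x * blockDim.x + threadIdx.x] = threadIdx.x;
    }

    __global__ void __launch_bounds__(256, 8) bounded_kernel(float* out) {
      // 256 threads per block, at least 8 resident blocks per SM: implies a register budget
      out[blockIdx.x * blockDim.x + threadIdx.x] = threadIdx.x;
    }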
--- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 8e93b0f91c..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1795,9 +1795,7 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. template -__global__ -__maxnreg__(32) -void finalizeMoeRoutingKernel( +__global__ void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, From 6ba3ab44bbc3ea590943c50593b0d0c5a3920a06 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:15:21 +0800 Subject: [PATCH 061/103] mv load ordering --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..1a5ef2ae57 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1839,13 +1839,14 @@ __global__ void finalizeMoeRoutingKernel( for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - if (expert_id < 0 || expert_id >= num_experts_per_node) { - continue; - } int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + if (expert_id < 0 || expert_id >= num_experts_per_node) { + continue; + } + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; From 6bcaa697d32694f5430bb4c15ac17896ff3253a4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 12:17:51 +0800 Subject: [PATCH 062/103] Revert "mv load ordering" This reverts commit 6ba3ab44bbc3ea590943c50593b0d0c5a3920a06. 
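What the reverted commit tried: issue the unpermuted_row_to_permuted_row lookup before the expert-id bounds check instead of after it, so the two independent global loads are in flight together rather than separated by a branch. The general shape of that reordering, reduced to a toy example:

    __device__ float gather_after_check(int const* expert_ids, int const* row_map,
                                        float const* table, int i, int num_experts) {
      int expert = expert_ids[i];
      if (expert < 0 || expert >= num_experts) return 0.f;  // branch sits between the two loads
      int row = row_map[i];
      return table[row];
    }

    __device__ float gather_loads_first(int const* expert_ids, int const* row_map,
                                        float const* table, int i, int num_experts) {
      int expert = expert_ids[i];                           // both loads issued back to back ...
      int row = row_map[i];
      if (expert < 0 || expert >= num_experts) return 0.f;  // ... then the check
      return table[row];
    }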
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 1a5ef2ae57..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1839,14 +1839,13 @@ __global__ void finalizeMoeRoutingKernel( for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - - int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; From b1bb20502b722e6123ff60f9c9df19c038b64813 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:42:37 +0800 Subject: [PATCH 063/103] make orig_cols constexpr --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..03a2c89631 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1798,10 +1798,12 @@ template ; using InputElem = cutlass::Array; From 6ca1b1d482d48649b14b6532d33921b4e4ea5a33 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:45:52 +0800 Subject: [PATCH 064/103] hack rm trap --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 03a2c89631..23eeb227af 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1801,9 +1801,9 @@ __global__ void finalizeMoeRoutingKernel( int const* token_selected_experts, int64_t const orig_cols_real_, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { constexpr int experts_per_token = 8; - if (experts_per_token != experts_per_token_real_) { asm("trap;"); } +// if (experts_per_token != experts_per_token_real_) { asm("trap;"); } constexpr int orig_cols = 7168; - if (orig_cols != orig_cols_real_) { asm("trap;"); } +// if (orig_cols != orig_cols_real_) { asm("trap;"); } int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; From fec2f4479fcd18114bbd5cee8a8ca87a68995eba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:46:38 +0800 Subject: [PATCH 065/103] Revert "hack rm trap" This reverts commit 6ca1b1d482d48649b14b6532d33921b4e4ea5a33. 
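Keeping the trap: the guard is what makes the constexpr specialization safe at all. experts_per_token is hard-coded to 8 so the k-loop can be fully unrolled, and the kernel aborts loudly if the runtime value ever disagrees instead of silently indexing the wrong elements. The pattern in isolation:

    __global__ void specialized_kernel(float const* in, float* out, int k_runtime) {
      constexpr int k = 8;                   // compile-time value the loop is unrolled for
      if (k != k_runtime) { asm("trap;"); }  // fail fast if the assumption is violated
    #pragma unroll
      for (int i = 0; i < k; ++i) {
        int idx = threadIdx.x * k + i;
        out[idx] = in[idx];
      }
    }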
--- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 23eeb227af..03a2c89631 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1801,9 +1801,9 @@ __global__ void finalizeMoeRoutingKernel( int const* token_selected_experts, int64_t const orig_cols_real_, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { constexpr int experts_per_token = 8; -// if (experts_per_token != experts_per_token_real_) { asm("trap;"); } + if (experts_per_token != experts_per_token_real_) { asm("trap;"); } constexpr int orig_cols = 7168; -// if (orig_cols != orig_cols_real_) { asm("trap;"); } + if (orig_cols != orig_cols_real_) { asm("trap;"); } int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; From 44d390e00f5e2d669095eba066333d8920c00320 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:46:38 +0800 Subject: [PATCH 066/103] Revert "make orig_cols constexpr" This reverts commit b1bb20502b722e6123ff60f9c9df19c038b64813. --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 03a2c89631..b3585b7f10 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1798,12 +1798,10 @@ template ; using InputElem = cutlass::Array; From 14d26e8167ba9a5fc9322db29b0f84ec4496b6bf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:50:40 +0800 Subject: [PATCH 067/103] cp: vectorize --- .../cutlass_fused_moe_kernels.cuh | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index b3585b7f10..8267abbe36 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -1800,6 +1801,10 @@ __global__ void finalizeMoeRoutingKernel( ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { +if constexpr (not (std::is_same_v and std::is_same_v)) { + printf("finalizeMoeRoutingKernel see unsupported dtype\n"); + asm("trap;"); +} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1856,8 +1861,11 @@ __global__ void finalizeMoeRoutingKernel( auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - ComputeElem expert_result = - arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); +// ComputeElem expert_result = +// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); + const int4 input_val = 
*reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); @@ -1866,13 +1874,19 @@ __global__ void finalizeMoeRoutingKernel( thread_output = thread_output + row_scale * expert_result; } - OutputElem output_elem = arrayConvert(thread_output); - reduced_row_ptr_v[elem_index] = output_elem; +// OutputElem output_elem = arrayConvert(thread_output); +// reduced_row_ptr_v[elem_index] = output_elem; + // TODO alignment issue? + __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); + int4 output_elem = *reinterpret_cast(&output_elem_original); + static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } +} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From 4480029efa5b84f664c8428411fe86c8a87ecf9d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:50:54 +0800 Subject: [PATCH 068/103] cp: mv load ordering --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 8267abbe36..edcfcd23da 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1844,13 +1844,14 @@ if constexpr (not (std::is_same_v and std::is_sam for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - if (expert_id < 0 || expert_id >= num_experts_per_node) { - continue; - } int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + if (expert_id < 0 || expert_id >= num_experts_per_node) { + continue; + } + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; From 80b704e024da9a2359eccd3c96b2c30969c9d51f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 13:58:52 +0800 Subject: [PATCH 069/103] hack: rm bias handling (incorrect? why is it used?) 
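This is a measurement hack rather than a fix: assuming the benchmark passes no bias (which the "why is it used?" in the subject questions), the branch is dead at runtime, and commenting it out only checks whether carrying the bias pointer and its per-expert indexing costs registers in the read-8/compute-8 restructuring that comes next. If the bias path does matter, the cleaner long-term shape is a compile-time toggle (a sketch, not the actual kernel signature):

    template <bool kHasBias>
    __global__ void finalize_like_kernel(float const* acc, float const* bias, float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      float v = acc[i];
      if constexpr (kHasBias) {
        v += bias[i];  // compiled out entirely when kHasBias == false
      }
      out[i] = v;
    }

    // Host-side dispatch on whether a bias tensor was actually provided:
    //   auto* kernel = bias ? finalize_like_kernel<true> : finalize_like_kernel<false>;
    //   kernel<<<blocks, threads>>>(acc, bias, out, n);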
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index edcfcd23da..af6dce119e 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1867,10 +1867,10 @@ if constexpr (not (std::is_same_v and std::is_sam static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); const int4 input_val = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); - if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; - expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); - } +// if (bias) { +// auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; +// expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); +// } thread_output = thread_output + row_scale * expert_result; } From f954a1f46109ea4ee388aa5fa5b5cf1759b4ba59 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:07:05 +0800 Subject: [PATCH 070/103] hack: read 8 - compute 8, instead of 8x(read 1 compute 1) --- .../cutlass_fused_moe_kernels.cuh | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index af6dce119e..acb8c82b67 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1840,6 +1840,8 @@ if constexpr (not (std::is_same_v and std::is_sam ComputeElem thread_output; thread_output.fill(0); + int4 input_val_buf[experts_per_token]; + #pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; @@ -1857,15 +1859,24 @@ if constexpr (not (std::is_same_v and std::is_sam continue; } - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; - auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; // ComputeElem expert_result = // arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - const int4 input_val = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + } + +#pragma unroll + for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { + int64_t const k_offset = original_row * experts_per_token + k_idx; + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + + // TODO + // TODO incorrect! 
need to skip some cases + // TODO + int4 input_val = input_val_buf[k_idx]; ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); // if (bias) { // auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; From 9603cf747cab87baab4e9cb712ffdfd9d0289ce4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:14:04 +0800 Subject: [PATCH 071/103] naive handle enable_input_buf --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index acb8c82b67..62679f1003 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1841,6 +1841,7 @@ if constexpr (not (std::is_same_v and std::is_sam thread_output.fill(0); int4 input_val_buf[experts_per_token]; + bool enable_input_buf[experts_per_token]; #pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1851,11 +1852,13 @@ if constexpr (not (std::is_same_v and std::is_sam int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; if (expert_id < 0 || expert_id >= num_experts_per_node) { + enable_input_buf[k_idx] = false; continue; } int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { + enable_input_buf[k_idx] = false; continue; } @@ -1866,16 +1869,16 @@ if constexpr (not (std::is_same_v and std::is_sam // arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + enable_input_buf[k_idx] = true; } #pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { + if (!enable_input_buf[k_idx]) continue; + int64_t const k_offset = original_row * experts_per_token + k_idx; float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; - // TODO - // TODO incorrect! 
need to skip some cases - // TODO int4 input_val = input_val_buf[k_idx]; ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); // if (bias) { From 2e52a49096273b8b32b9be6484573fb242ed4429 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:19:22 +0800 Subject: [PATCH 072/103] enable_input_buf use bitwise op --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 62679f1003..9b4439df17 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1841,7 +1841,7 @@ if constexpr (not (std::is_same_v and std::is_sam thread_output.fill(0); int4 input_val_buf[experts_per_token]; - bool enable_input_buf[experts_per_token]; + uint32_t enable_input_buf = 0; #pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { @@ -1852,13 +1852,11 @@ if constexpr (not (std::is_same_v and std::is_sam int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; if (expert_id < 0 || expert_id >= num_experts_per_node) { - enable_input_buf[k_idx] = false; continue; } int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { - enable_input_buf[k_idx] = false; continue; } @@ -1869,12 +1867,12 @@ if constexpr (not (std::is_same_v and std::is_sam // arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); - enable_input_buf[k_idx] = true; + enable_input_buf |= 1 << k_idx; } #pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { - if (!enable_input_buf[k_idx]) continue; + if (not (enable_input_buf & (1 << k_idx))) continue; int64_t const k_offset = original_row * experts_per_token + k_idx; float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; From 1c75b9bb2b5adef97a7df36bf56a649ac041bf22 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:23:47 +0800 Subject: [PATCH 073/103] hack: unpermute, maxnreg=64 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9b4439df17..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1796,7 +1796,9 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. 
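Patches 070-072 above reshape the finalize loop from k interleaved "load one row, then accumulate it" steps into a load phase followed by a compute phase, with a per-slot validity bitmask replacing the first attempt's bool array so skipped experts are still handled correctly. Below is a minimal standalone sketch of that gather-then-reduce pattern, not the finalize kernel itself; the names (gather_then_reduce, row_map, vecs_per_row) and the plain float4 payload are illustrative assumptions.

// Sketch of "read K rows first, then convert/scale/accumulate", with a bitmask
// recording which of the K slots actually produced a load.
#include <cstdint>

template <int K>
__global__ void gather_then_reduce(const int4* __restrict__ rows,     // permuted rows, 16B-packed
                                   const int* __restrict__ row_map,   // token * K + k -> permuted row, or -1
                                   const float* __restrict__ scales,  // token * K + k -> routing weight
                                   float4* __restrict__ out,          // one reduced row per token
                                   int64_t vecs_per_row) {
  const int64_t token = blockIdx.x;
  for (int64_t v = threadIdx.x; v < vecs_per_row; v += blockDim.x) {
    int4 buf[K];
    uint32_t valid = 0;  // bit k set => buf[k] holds a loaded vector
#pragma unroll
    for (int k = 0; k < K; ++k) {  // phase 1: issue all loads back to back
      const int perm = row_map[token * K + k];
      if (perm < 0) continue;
      buf[k] = rows[int64_t(perm) * vecs_per_row + v];
      valid |= 1u << k;
    }
    float4 acc = make_float4(0.f, 0.f, 0.f, 0.f);
#pragma unroll
    for (int k = 0; k < K; ++k) {  // phase 2: convert, scale, accumulate
      if (!(valid & (1u << k))) continue;
      const float s = scales[token * K + k];
      const float4 x = *reinterpret_cast<const float4*>(&buf[k]);
      acc.x += s * x.x; acc.y += s * x.y; acc.z += s * x.z; acc.w += s * x.w;
    }
    out[token * vecs_per_row + v] = acc;
  }
}

Issuing the top-k loads back to back lets more of them be in flight at once, which is the point of the "read 8 - compute 8, instead of 8x(read 1 compute 1)" commit.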
template -__global__ void finalizeMoeRoutingKernel( +__global__ +__maxnreg__(64) +void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, From 4e406240ac2f007e2bf9e1fb72569a03c1e08b27 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:33:24 +0800 Subject: [PATCH 074/103] doActivationKernel reg=32 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..a6b8e706a8 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2208,7 +2208,9 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ template -__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, +__global__ +__maxnreg__(32) +void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size_real_, From ea5a59405288e119f316f8078ee2fa972713c238 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:38:11 +0800 Subject: [PATCH 075/103] Revert "doActivationKernel reg=32" This reverts commit 4e406240ac2f007e2bf9e1fb72569a03c1e08b27. --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index a6b8e706a8..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2208,9 +2208,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ template -__global__ -__maxnreg__(32) -void doActivationKernel(T* output, GemmOutputType const* gemm_result, +__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size_real_, From a5d4189702bf6f74b75a6fd48655b847572c4518 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:38:28 +0800 Subject: [PATCH 076/103] hack: acti blocks 8->6 --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..7a56c586a7 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
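Patches 073-075 above are register-pressure experiments: __maxnreg__(64) stays on the unpermute/finalize kernel, while __maxnreg__(32) on doActivationKernel is tried and reverted. For reference, here are two toy grid-stride kernels showing the attribute used in the patch and the older __launch_bounds__ way of bounding registers indirectly; __maxnreg__ needs a recent CUDA toolkit, and these kernels are illustrations, not the MoE kernels.

// Toy copy kernels: two ways to cap per-thread register usage.
#include <cstdint>

// Direct cap, same attribute as in the patch above (newer toolkits only).
__global__ __maxnreg__(64) void copy_capped(const float4* __restrict__ in,
                                            float4* __restrict__ out, int64_t n) {
  for (int64_t i = blockIdx.x * int64_t(blockDim.x) + threadIdx.x; i < n;
       i += int64_t(gridDim.x) * blockDim.x)
    out[i] = in[i];
}

// Indirect cap: promise 256 threads per block and at least 8 resident blocks
// per SM; ptxas then limits registers so that occupancy target is reachable.
__global__ void __launch_bounds__(256, 8) copy_bounded(const float4* __restrict__ in,
                                                       float4* __restrict__ out, int64_t n) {
  for (int64_t i = blockIdx.x * int64_t(blockDim.x) + threadIdx.x; i < n;
       i += int64_t(gridDim.x) * blockDim.x)
    out[i] = in[i];
}

Compiling with -Xptxas -v prints the resulting registers per thread, which is the quickest way to confirm the cap took effect.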
- int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 6, std::max(expanded_num_tokens, num_padding_tokens)); int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; auto fn = [&]() { From b52391ac4292f70863d06b47b0428b4973071409 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:44:14 +0800 Subject: [PATCH 077/103] Revert "hack: acti blocks 8->6" This reverts commit a5d4189702bf6f74b75a6fd48655b847572c4518. --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7a56c586a7..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). - int64_t const blocks = std::min(smCount * 6, std::max(expanded_num_tokens, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; auto fn = [&]() { From 7853d15b3bcce80a02d84c0369e9465c8ee89c3f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:45:24 +0800 Subject: [PATCH 078/103] hack: acti - infinite num blocks --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..20119002fc 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). - int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 16384, std::max(expanded_num_tokens, num_padding_tokens)); int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; auto fn = [&]() { From f670fa4d9403f071b744f74962bcbbc7468d69b5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 14:56:49 +0800 Subject: [PATCH 079/103] hack: acti - mid num blocks --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 20119002fc..79983da790 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
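Patches 076-078 above only change the grid-size multiplier for the activation kernel (8 -> 6, revert, then an effectively unbounded 16384). Assuming the kernel walks expanded tokens with a block-stride loop, which the min(smCount * N, expanded tokens) sizing suggests, every multiplier produces the same result and the sweep is purely about wave count versus tail effects. A hypothetical kernel of that shape (not doActivationKernel) is sketched below.

// Hypothetical row-stride activation: correctness does not depend on gridDim.x,
// so the blocks-per-SM factor can be swept freely when tuning for bandwidth.
#include <cstdint>

__global__ void silu_rows(float* __restrict__ x, int64_t num_rows, int64_t row_len) {
  for (int64_t row = blockIdx.x; row < num_rows; row += gridDim.x) {
    float* r = x + row * row_len;
    for (int64_t i = threadIdx.x; i < row_len; i += blockDim.x) {
      const float v = r[i];
      r[i] = v / (1.0f + __expf(-v));  // SiLU, as a stand-in for the real activation
    }
  }
}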
-  int64_t const blocks = std::min(smCount * 16384, std::max(expanded_num_tokens, num_padding_tokens));
+  int64_t const blocks = std::min(smCount * 128, std::max(expanded_num_tokens, num_padding_tokens));
   int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
 
   auto fn = [&]() {

From 50378ba1418a33e8c39f4ea6eff25ce3f26f638c Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 19 Sep 2025 15:00:40 +0800
Subject: [PATCH 080/103] Revert "hack: acti - mid num blocks"

This reverts commit f670fa4d9403f071b744f74962bcbbc7468d69b5.
---
 csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
index 79983da790..20119002fc 100644
--- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
+++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
@@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
   static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
   // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
-  int64_t const blocks = std::min(smCount * 128, std::max(expanded_num_tokens, num_padding_tokens));
+  int64_t const blocks = std::min(smCount * 16384, std::max(expanded_num_tokens, num_padding_tokens));
   int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
 
   auto fn = [&]() {

From 9d1456a5a6d51b5b6ce3168e23fb769c5f6d1430 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 19 Sep 2025 15:00:40 +0800
Subject: [PATCH 081/103] Revert "hack: acti - infinite num blocks"

This reverts commit 7853d15b3bcce80a02d84c0369e9465c8ee89c3f.
---
 csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
index 20119002fc..d5543fb122 100644
--- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
+++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
@@ -2439,7 +2439,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
   static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
   // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
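The sweep settles back on the original 8 blocks per SM (6, 128 and the effectively unbounded 16384 are all reverted). As an aside, a per-kernel bound can also be derived from the occupancy API instead of a hand-tuned multiplier; the helper below is a hypothetical alternative, not what the patches do, and kernel_func / block_size are placeholders.

// Hypothetical helper: size the grid from occupancy rather than smCount * factor.
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

inline int64_t pick_grid_size(const void* kernel_func, int block_size,
                              int64_t num_work_items) {
  int device = 0, sm_count = 0, blocks_per_sm = 0;
  cudaGetDevice(&device);  // error checking omitted in this sketch
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel_func,
                                                block_size, /*dynamicSMemSize=*/0);
  // One full wave of resident blocks, but never more blocks than work items.
  return std::min<int64_t>(int64_t(sm_count) * blocks_per_sm, num_work_items);
}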
- int64_t const blocks = std::min(smCount * 16384, std::max(expanded_num_tokens, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; auto fn = [&]() { From ec89f0c76964c482755de25b58229524f910dfef Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 15:32:58 +0800 Subject: [PATCH 082/103] temp rm all --- .../cutlass_fused_moe_kernels.cuh | 134 ++++-------------- 1 file changed, 24 insertions(+), 110 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..d94bc69b23 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -866,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken( // ============================== Infer GEMM sizes ================================= // TODO Could linear search be better for small # experts template -__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices, +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) { @@ -882,49 +881,6 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices return target_location + 1; } -template -__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - constexpr int ARR_LENGTH_CONST = 128; - if (arr_length != ARR_LENGTH_CONST) { - asm("trap;"); - } - - constexpr unsigned full_mask = 0xffffffffu; - constexpr int WARP_SZ = 32; - const int lane_id = threadIdx.x & (WARP_SZ - 1); - - int local_count = 0; -#pragma unroll - for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { - const int idx = lane_id + k * WARP_SZ; - T v = sorted_indices[idx]; - local_count += (v < target) ? 1 : 0; - } - -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_count += __shfl_down_sync(full_mask, local_count, offset); - } - int total = __shfl_sync(full_mask, local_count, 0); - - return (int64_t)total; -} - -template -__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { -// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); - - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); - -// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); -// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); -// if (out_v1 != out_v2) { -// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); -// asm("trap;"); -// } -// return out_v1; -} - template using sizeof_bits = cutlass::sizeof_bits< typename cutlass_kernels::TllmToCutlassTypeAdapter>::type>; @@ -1458,7 +1414,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. 
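Patch 082 backs out the kernel experiments wholesale, including the warp-parallel replacement for the binary search in findTotalEltsLessThanTarget. Here is a sketch of that removed variant written out with explicit template syntax; the <class T> parameter and the helper name are assumptions reconstructed from the surviving body (128 sorted entries scanned by one warp, then a shuffle reduction).

// Sketch of the removed warp-parallel count: each of the 32 lanes checks 4 of
// the 128 sorted entries, then a shuffle reduction sums the per-lane counts and
// lane 0's total is broadcast to the whole warp.
#include <cstdint>

template <class T>
__device__ inline int64_t countLessThanWarp(T const* sorted_indices, T const target) {
  constexpr int kArrLength = 128;
  constexpr int kWarpSize = 32;
  constexpr unsigned kFullMask = 0xffffffffu;
  const int lane_id = threadIdx.x & (kWarpSize - 1);

  int local_count = 0;
#pragma unroll
  for (int k = 0; k < kArrLength / kWarpSize; ++k)
    local_count += (sorted_indices[lane_id + k * kWarpSize] < target) ? 1 : 0;

#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1)
    local_count += __shfl_down_sync(kFullMask, local_count, offset);

  return __shfl_sync(kFullMask, local_count, 0);
}

Compared with the per-thread binary search it replaces, this trades log2(n) divergent probes for a fixed four coalesced reads per lane plus a warp reduction, which is why it was worth trying for the 128-expert case.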
-constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template -__global__ -__maxnreg__(64) -void finalizeMoeRoutingKernel( +__global__ void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, - int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, + int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { -if constexpr (not (std::is_same_v and std::is_same_v)) { - printf("finalizeMoeRoutingKernel see unsupported dtype\n"); - asm("trap;"); -} else { - constexpr int experts_per_token = 8; - if (experts_per_token != experts_per_token_real_) { asm("trap;"); } - int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; @@ -1841,67 +1784,43 @@ if constexpr (not (std::is_same_v and std::is_sam for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); - - int4 input_val_buf[experts_per_token]; - uint32_t enable_input_buf = 0; - -#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - - int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; } + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; -// ComputeElem expert_result = -// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); - enable_input_buf |= 1 << k_idx; - } - -#pragma unroll - for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { - if (not (enable_input_buf & (1 << k_idx))) continue; - - int64_t const k_offset = original_row * experts_per_token + k_idx; - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 
1.f : scales[k_offset]; - - int4 input_val = input_val_buf[k_idx]; - ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); -// if (bias) { -// auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; -// expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); -// } + ComputeElem expert_result = + arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + if (bias) { + auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); + } thread_output = thread_output + row_scale * expert_result; } -// OutputElem output_elem = arrayConvert(thread_output); -// reduced_row_ptr_v[elem_index] = output_elem; - // TODO alignment issue? - __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); - int4 output_elem = *reinterpret_cast(&output_elem_original); - static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); - *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; + OutputElem output_elem = arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip @@ -2211,13 +2130,10 @@ template && @@ -2302,9 +2218,9 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto output_vec = reinterpret_cast(safe_inc_ptr(output, output_offset)); auto bias_ptr_vec = reinterpret_cast(bias_ptr + bias_offset); int64_t const start_offset = tid; - constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK; + int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0); - constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; + int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; @@ -2312,8 +2228,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fn.alpha = gate_alpha; fn.beta = gate_beta; fn.limit = gate_limit; - -#pragma unroll for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert(gemm_result_vec[elem_index + gated_off_vec]); From 783120b09d6f953b54db0366748f644ec628e96d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 15:34:49 +0800 Subject: [PATCH 083/103] change test --- tests/test_trtllm_cutlass_fused_moe.py | 1653 ++++++++++++------------ 1 file changed, 833 insertions(+), 820 deletions(-) diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py index 5f9f04dc63..c1cbbc8970 100644 --- a/tests/test_trtllm_cutlass_fused_moe.py +++ b/tests/test_trtllm_cutlass_fused_moe.py @@ -211,144 +211,156 @@ def compute_with_experts( # Test configurations +# BATCH_SIZES = [ +# 1, +# ] +# HIDDEN_SIZES = [ +# 128, +# ] +# NUM_EXPERTS = [2] +# TOP_K_VALUES = [2] +# INTERMEDIATE_SIZES = [ +# 128, +# ] +# EP_NUM_EXPERTS = [8] +# EP_TOP_K = [2] + +# NOTE MODIFIED BATCH_SIZES = [ 1, ] HIDDEN_SIZES = [ - 128, + 7168, ] -NUM_EXPERTS = [2] -TOP_K_VALUES = [2] +NUM_EXPERTS = [128] +TOP_K_VALUES = [8] INTERMEDIATE_SIZES = [ - 128, + 2048, ] -EP_NUM_EXPERTS = [8] -EP_TOP_K = [2] - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) 
-@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size): - # Skip invalid configurations - if top_k > num_experts: - pytest.skip( - f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" - ) - - torch.manual_seed(42) - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5 - router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda() - w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 - ).cuda() - / 5 - ) - w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 - ).cuda() - / 5 - ) - - routing_weights, selected_experts = compute_routing(router_logits, top_k) - ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - flash_output = torch.empty_like(ref_output) - flash_output = fused_moe.cutlass_fused_moe( - x, - selected_experts.to(torch.int), - routing_weights, - w31_weight, - w2_weight, - flash_output.dtype, - output=flash_output, - quant_scales=None, - ) - - torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)]) -def test_moe_fp8( - batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype -): - # Skip invalid configurations - if top_k > num_experts: - pytest.skip( - f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" - ) - - torch.manual_seed(42) - input_shape = (batch_size, hidden_size) - w31_shape = (num_experts, 2 * intermediate_size, hidden_size) - w2_shape = (num_experts, hidden_size, intermediate_size) - x = cast_to_representable(gen_tensor(input_shape, otype)) - router_logits = gen_tensor((batch_size, num_experts), otype) - - # Create weight tensors - w31_weight = gen_tensor(w31_shape, otype, wtype) - w2_weight = gen_tensor(w2_shape, otype, wtype) - w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda() - w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda() - - w31_dequantized = gen_tensor(w31_shape, otype) - w2_dequantized = gen_tensor(w2_shape, otype) - for expert_id in range(num_experts): - w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1)) - w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09)) - - w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31) - w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2) - - w31_weight.data[expert_id].copy_(w31_quant) - w2_weight.data[expert_id].copy_(w2_quant) - w31_scales.data[expert_id].copy_(s31) - w2_scales.data[expert_id].copy_(s2) - w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31)) - w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2)) - - routing_weights, selected_experts = compute_routing(router_logits, top_k) - ref_output = compute_with_experts( - num_experts, - x, - w31_dequantized, - w2_dequantized, - selected_experts, - routing_weights, - ) - flash_output = 
torch.empty_like(ref_output) - # For fp8, the hidden_state expects quantized. - _, w1_scales = torch.chunk(w31_scales, 2, dim=-1) - x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x) - hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda() - quant_scales = [ - torch.squeeze(w1_scales * hidden_states_scale).float(), - torch.tensor(1.0).cuda(), - torch.squeeze(1.0 * w2_scales).float(), - hidden_states_scale, - ] - - _ = fused_moe.cutlass_fused_moe( - x_quant, - selected_experts.to(torch.int), - routing_weights, - w31_weight, - w2_weight, - otype, - quant_scales=quant_scales, - output=flash_output, - ) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", TOP_K_VALUES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size): +# # Skip invalid configurations +# if top_k > num_experts: +# pytest.skip( +# f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" +# ) +# +# torch.manual_seed(42) +# x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5 +# router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda() +# w31_weight = ( +# torch.randn( +# num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 +# ).cuda() +# / 5 +# ) +# w2_weight = ( +# torch.randn( +# num_experts, hidden_size, intermediate_size, dtype=torch.float16 +# ).cuda() +# / 5 +# ) +# +# routing_weights, selected_experts = compute_routing(router_logits, top_k) +# ref_output = compute_with_experts( +# num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights +# ) +# flash_output = torch.empty_like(ref_output) +# flash_output = fused_moe.cutlass_fused_moe( +# x, +# selected_experts.to(torch.int), +# routing_weights, +# w31_weight, +# w2_weight, +# flash_output.dtype, +# output=flash_output, +# quant_scales=None, +# ) +# +# torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2) + + +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", TOP_K_VALUES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# @pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)]) +# def test_moe_fp8( +# batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype +# ): +# # Skip invalid configurations +# if top_k > num_experts: +# pytest.skip( +# f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" +# ) +# +# torch.manual_seed(42) +# input_shape = (batch_size, hidden_size) +# w31_shape = (num_experts, 2 * intermediate_size, hidden_size) +# w2_shape = (num_experts, hidden_size, intermediate_size) +# x = cast_to_representable(gen_tensor(input_shape, otype)) +# router_logits = gen_tensor((batch_size, num_experts), otype) +# +# # Create weight tensors +# w31_weight = gen_tensor(w31_shape, otype, wtype) +# w2_weight = gen_tensor(w2_shape, otype, wtype) +# w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda() +# w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda() +# +# w31_dequantized = gen_tensor(w31_shape, otype) +# w2_dequantized = gen_tensor(w2_shape, otype) +# for expert_id in 
range(num_experts): +# w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1)) +# w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09)) +# +# w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31) +# w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2) +# +# w31_weight.data[expert_id].copy_(w31_quant) +# w2_weight.data[expert_id].copy_(w2_quant) +# w31_scales.data[expert_id].copy_(s31) +# w2_scales.data[expert_id].copy_(s2) +# w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31)) +# w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2)) +# +# routing_weights, selected_experts = compute_routing(router_logits, top_k) +# ref_output = compute_with_experts( +# num_experts, +# x, +# w31_dequantized, +# w2_dequantized, +# selected_experts, +# routing_weights, +# ) +# flash_output = torch.empty_like(ref_output) +# # For fp8, the hidden_state expects quantized. +# _, w1_scales = torch.chunk(w31_scales, 2, dim=-1) +# x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x) +# hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda() +# quant_scales = [ +# torch.squeeze(w1_scales * hidden_states_scale).float(), +# torch.tensor(1.0).cuda(), +# torch.squeeze(1.0 * w2_scales).float(), +# hidden_states_scale, +# ] +# +# _ = fused_moe.cutlass_fused_moe( +# x_quant, +# selected_experts.to(torch.int), +# routing_weights, +# w31_weight, +# w2_weight, +# otype, +# quant_scales=quant_scales, +# output=flash_output, +# ) +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @@ -360,7 +372,8 @@ def test_moe_fp8( "otype, wtype", [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)], ) -@pytest.mark.parametrize("quantized_input", [False, True]) +# @pytest.mark.parametrize("quantized_input", [False, True]) +@pytest.mark.parametrize("quantized_input", [True]) @pytest.mark.skipif( torch.cuda.get_device_capability()[0] not in [10, 11, 12], reason="NVFP4 is only supported on SM100, SM110 and SM120", @@ -511,327 +524,327 @@ def test_moe_nvfp4( ) torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1) - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS) -@pytest.mark.parametrize("top_k", EP_TOP_K) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -def test_moe_expert_parallel( - batch_size, hidden_size, num_experts, top_k, intermediate_size -): - """ - Test expert parallelism with X GPUs and Y experts. - Each GPU handles one expert and results are reduced. 
- - Args: - batch_size: Batch size for the input - hidden_size: Hidden dimension size - num_experts: Number of experts (must be 2 for this test) - top_k: Number of experts to route to per token - intermediate_size: Intermediate dimension size - activation: Activation function type - """ - # This test is specifically for 2 GPUs and 2 experts - # GPU 0 (ep_rank=0) handles expert 0 - # GPU 1 (ep_rank=1) handles expert 1 - ep_size = num_experts // 2 - torch.manual_seed(42) - - # Create input tensors - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() - - # Create weight tensors - each GPU will have one expert - w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 - ).cuda() - / 10 - ) - w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 - ).cuda() - / 10 - ) - - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] - ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - - outputs = [] - flash_output = torch.zeros_like(ref_output) - for ep_rank in range(ep_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Compute expert start and end positions for this rank - experts_per_rank = ( - num_experts // ep_size - ) # 2 GPUs, so each gets half the experts - expert_start = ep_rank * experts_per_rank - expert_end = expert_start + experts_per_rank # if ep_rank < 1 else num_experts - - w31_weight_local = w31_weight[ - expert_start:expert_end, : - ] # Get only the experts for this rank - w2_weight_local = w2_weight[ - expert_start:expert_end, : - ] # Get only the experts for this rank - - _ = fused_moe.cutlass_fused_moe( - x.contiguous(), - selected_experts.to(torch.int), - routing_weights, - w31_weight_local.contiguous(), - w2_weight_local.contiguous(), - x.dtype, - ep_size=ep_size, - ep_rank=ep_rank, - quant_scales=None, - output=out_hidden_states_local, - ) - outputs.append(out_hidden_states_local) - - # Reduce results from all GPUs - for ep_rank in range(ep_size): - flash_output += outputs[ep_rank] # [batch_size, num_experts] - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) - - -TP_SIZES = [2, 4] - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("tp_size", TP_SIZES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -def test_moe_tensor_parallel( - batch_size, hidden_size, num_experts, tp_size, intermediate_size -): - """ - Test tensor parallelism with: - - w31 sharded along second dimension (non-contracting) - - w2 sharded along third dimension (contracting) - - All-reduce to sum partial results - - Args: - batch_size: Batch size for the input - hidden_size: Hidden dimension size - num_experts: Number of experts - top_k: Number of experts to route to per token - intermediate_size: Intermediate dimension size - activation: Activation function type - """ - # Set random seed for reproducibility - torch.manual_seed(42) - top_k = 2 - # Create input tensors - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() - - # Create weight tensors - w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, 
dtype=torch.float16 - ).cuda() - / 10 - ) - w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 - ).cuda() - / 10 - ) - - # Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] - ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no parallelism) - ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - - # Simulate tensor parallelism on # TP GPUs - outputs = [] - for tp_rank in range(tp_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Shard w31 along second dimension (intermediate_size) - # First split w31 into w3 and w1 - w3_weight, w1_weight = torch.chunk( - w31_weight, 2, dim=1 - ) # [num_experts, intermediate_size, hidden_size] each - - # Shard w3 and w1 separately - w3_shard_size = intermediate_size // tp_size - w3_start = tp_rank * w3_shard_size - w3_end = w3_start + w3_shard_size - w3_weight_local = w3_weight[:, w3_start:w3_end, :] - - w1_shard_size = intermediate_size // tp_size - w1_start = tp_rank * w1_shard_size - w1_end = w1_start + w1_shard_size - w1_weight_local = w1_weight[:, w1_start:w1_end, :] - - # Stack the sharded weights back together - w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1) - - # Shard w2 along third dimension (intermediate_size) - w2_shard_size = intermediate_size // tp_size - w2_start = tp_rank * w2_shard_size - w2_end = w2_start + w2_shard_size - w2_weight_local = w2_weight[:, :, w2_start:w2_end] - - _ = fused_moe.cutlass_fused_moe( - x.contiguous(), - selected_experts.to(torch.int), - routing_weights, - w31_weight_local.contiguous(), - w2_weight_local.contiguous(), - x.dtype, - tp_size=tp_size, - tp_rank=tp_rank, - quant_scales=None, - output=out_hidden_states_local, - ) - outputs.append(out_hidden_states_local) - - # All-reduce to sum partial results from all GPUs - flash_output = sum(outputs) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS) -@pytest.mark.parametrize("top_k", EP_TOP_K) -@pytest.mark.parametrize("tp_size", TP_SIZES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -def test_moe_tensor_expert_parallel( - batch_size, hidden_size, num_experts, top_k, tp_size, intermediate_size -): - """ - Test combined tensor parallelism and expert parallelism: - - Expert parallelism: Distribute experts across GPUs - - Tensor parallelism: For each expert's weights: - - w31 sharded along second dimension (non-contracting) - - w2 sharded along third dimension (contracting) - - All-reduce to sum partial results - - Args: - batch_size: Batch size for the input - hidden_size: Hidden dimension size - num_experts: Number of experts - tp_size: Number of GPUs for tensor parallelism - intermediate_size: Intermediate dimension size - """ - torch.manual_seed(42) - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() - w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 - ).cuda() - / 10 - ) - w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 - ).cuda() - / 10 - ) - - # 
Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] - ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no parallelism) - ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - - # Simulate combined parallelism - ep_size = num_experts // 2 # Number of GPUs for expert parallelism - outputs = [] - - # For each expert parallel rank - for ep_rank in range(ep_size): - # Get experts for this rank - experts_per_rank = num_experts // ep_size - expert_start = ep_rank * experts_per_rank - expert_end = expert_start + experts_per_rank - - # Get expert weights for this rank - w31_weight_ep = w31_weight[ - expert_start:expert_end, : - ] # [experts_per_rank, 2*intermediate_size, hidden_size] - w2_weight_ep = w2_weight[ - expert_start:expert_end, : - ] # [experts_per_rank, hidden_size, intermediate_size] - - # For each tensor parallel rank - for tp_rank in range(tp_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Split w31 into w3 and w1 - w3_weight, w1_weight = torch.chunk(w31_weight_ep, 2, dim=1) - - # Shard w3 and w1 separately - w3_shard_size = intermediate_size // tp_size - w3_start = tp_rank * w3_shard_size - w3_end = w3_start + w3_shard_size - w3_weight_local = w3_weight[:, w3_start:w3_end, :] - - w1_shard_size = intermediate_size // tp_size - w1_start = tp_rank * w1_shard_size - w1_end = w1_start + w1_shard_size - w1_weight_local = w1_weight[:, w1_start:w1_end, :] - - # Stack the sharded weights back together - w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1) - - # Shard w2 along third dimension - w2_shard_size = intermediate_size // tp_size - w2_start = tp_rank * w2_shard_size - w2_end = w2_start + w2_shard_size - w2_weight_local = w2_weight_ep[:, :, w2_start:w2_end] - - # Call flashinfer implementation with both parallelisms - out_hidden_states_local = fused_moe.cutlass_fused_moe( - x.contiguous(), - selected_experts.to(torch.int), - routing_weights, - w31_weight_local.contiguous(), - w2_weight_local.contiguous(), - x.dtype, - tp_size=tp_size, - tp_rank=tp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - quant_scales=None, - ) - outputs.append(out_hidden_states_local[0]) - - # All-reduce to sum partial results from all GPUs - flash_output = sum(outputs) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) +# +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", EP_TOP_K) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# def test_moe_expert_parallel( +# batch_size, hidden_size, num_experts, top_k, intermediate_size +# ): +# """ +# Test expert parallelism with X GPUs and Y experts. +# Each GPU handles one expert and results are reduced. 
+# +# Args: +# batch_size: Batch size for the input +# hidden_size: Hidden dimension size +# num_experts: Number of experts (must be 2 for this test) +# top_k: Number of experts to route to per token +# intermediate_size: Intermediate dimension size +# activation: Activation function type +# """ +# # This test is specifically for 2 GPUs and 2 experts +# # GPU 0 (ep_rank=0) handles expert 0 +# # GPU 1 (ep_rank=1) handles expert 1 +# ep_size = num_experts // 2 +# torch.manual_seed(42) +# +# # Create input tensors +# x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() +# +# # Create weight tensors - each GPU will have one expert +# w31_weight = ( +# torch.randn( +# num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# w2_weight = ( +# torch.randn( +# num_experts, hidden_size, intermediate_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# +# selected_experts = torch.stack( +# [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] +# ).cuda() +# +# routing_weights = torch.randn((batch_size, top_k)).cuda() +# routing_weights = F.softmax(routing_weights, dim=1) +# ref_output = compute_with_experts( +# num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights +# ) +# +# outputs = [] +# flash_output = torch.zeros_like(ref_output) +# for ep_rank in range(ep_size): +# # Create output tensor for this GPU +# out_hidden_states_local = torch.zeros_like(x) +# +# # Compute expert start and end positions for this rank +# experts_per_rank = ( +# num_experts // ep_size +# ) # 2 GPUs, so each gets half the experts +# expert_start = ep_rank * experts_per_rank +# expert_end = expert_start + experts_per_rank # if ep_rank < 1 else num_experts +# +# w31_weight_local = w31_weight[ +# expert_start:expert_end, : +# ] # Get only the experts for this rank +# w2_weight_local = w2_weight[ +# expert_start:expert_end, : +# ] # Get only the experts for this rank +# +# _ = fused_moe.cutlass_fused_moe( +# x.contiguous(), +# selected_experts.to(torch.int), +# routing_weights, +# w31_weight_local.contiguous(), +# w2_weight_local.contiguous(), +# x.dtype, +# ep_size=ep_size, +# ep_rank=ep_rank, +# quant_scales=None, +# output=out_hidden_states_local, +# ) +# outputs.append(out_hidden_states_local) +# +# # Reduce results from all GPUs +# for ep_rank in range(ep_size): +# flash_output += outputs[ep_rank] # [batch_size, num_experts] +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) +# +# +# TP_SIZES = [2, 4] +# +# +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("tp_size", TP_SIZES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# def test_moe_tensor_parallel( +# batch_size, hidden_size, num_experts, tp_size, intermediate_size +# ): +# """ +# Test tensor parallelism with: +# - w31 sharded along second dimension (non-contracting) +# - w2 sharded along third dimension (contracting) +# - All-reduce to sum partial results +# +# Args: +# batch_size: Batch size for the input +# hidden_size: Hidden dimension size +# num_experts: Number of experts +# top_k: Number of experts to route to per token +# intermediate_size: Intermediate dimension size +# activation: Activation function type +# """ +# # Set random seed for reproducibility +# torch.manual_seed(42) +# top_k = 2 +# # Create input tensors +# x = torch.randn(batch_size, hidden_size, 
dtype=torch.float16).cuda() +# +# # Create weight tensors +# w31_weight = ( +# torch.randn( +# num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# w2_weight = ( +# torch.randn( +# num_experts, hidden_size, intermediate_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# +# # Generate unique random expert indices for each token +# selected_experts = torch.stack( +# [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] +# ).cuda() +# +# routing_weights = torch.randn((batch_size, top_k)).cuda() +# routing_weights = F.softmax(routing_weights, dim=1) +# +# # Run reference implementation (no parallelism) +# ref_output = compute_with_experts( +# num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights +# ) +# +# # Simulate tensor parallelism on # TP GPUs +# outputs = [] +# for tp_rank in range(tp_size): +# # Create output tensor for this GPU +# out_hidden_states_local = torch.zeros_like(x) +# +# # Shard w31 along second dimension (intermediate_size) +# # First split w31 into w3 and w1 +# w3_weight, w1_weight = torch.chunk( +# w31_weight, 2, dim=1 +# ) # [num_experts, intermediate_size, hidden_size] each +# +# # Shard w3 and w1 separately +# w3_shard_size = intermediate_size // tp_size +# w3_start = tp_rank * w3_shard_size +# w3_end = w3_start + w3_shard_size +# w3_weight_local = w3_weight[:, w3_start:w3_end, :] +# +# w1_shard_size = intermediate_size // tp_size +# w1_start = tp_rank * w1_shard_size +# w1_end = w1_start + w1_shard_size +# w1_weight_local = w1_weight[:, w1_start:w1_end, :] +# +# # Stack the sharded weights back together +# w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1) +# +# # Shard w2 along third dimension (intermediate_size) +# w2_shard_size = intermediate_size // tp_size +# w2_start = tp_rank * w2_shard_size +# w2_end = w2_start + w2_shard_size +# w2_weight_local = w2_weight[:, :, w2_start:w2_end] +# +# _ = fused_moe.cutlass_fused_moe( +# x.contiguous(), +# selected_experts.to(torch.int), +# routing_weights, +# w31_weight_local.contiguous(), +# w2_weight_local.contiguous(), +# x.dtype, +# tp_size=tp_size, +# tp_rank=tp_rank, +# quant_scales=None, +# output=out_hidden_states_local, +# ) +# outputs.append(out_hidden_states_local) +# +# # All-reduce to sum partial results from all GPUs +# flash_output = sum(outputs) +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) +# +# +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", EP_TOP_K) +# @pytest.mark.parametrize("tp_size", TP_SIZES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# def test_moe_tensor_expert_parallel( +# batch_size, hidden_size, num_experts, top_k, tp_size, intermediate_size +# ): +# """ +# Test combined tensor parallelism and expert parallelism: +# - Expert parallelism: Distribute experts across GPUs +# - Tensor parallelism: For each expert's weights: +# - w31 sharded along second dimension (non-contracting) +# - w2 sharded along third dimension (contracting) +# - All-reduce to sum partial results +# +# Args: +# batch_size: Batch size for the input +# hidden_size: Hidden dimension size +# num_experts: Number of experts +# tp_size: Number of GPUs for tensor parallelism +# intermediate_size: Intermediate dimension size +# """ +# torch.manual_seed(42) +# x = torch.randn(batch_size, hidden_size, 
dtype=torch.float16).cuda() +# w31_weight = ( +# torch.randn( +# num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# w2_weight = ( +# torch.randn( +# num_experts, hidden_size, intermediate_size, dtype=torch.float16 +# ).cuda() +# / 10 +# ) +# +# # Generate unique random expert indices for each token +# selected_experts = torch.stack( +# [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] +# ).cuda() +# +# routing_weights = torch.randn((batch_size, top_k)).cuda() +# routing_weights = F.softmax(routing_weights, dim=1) +# +# # Run reference implementation (no parallelism) +# ref_output = compute_with_experts( +# num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights +# ) +# +# # Simulate combined parallelism +# ep_size = num_experts // 2 # Number of GPUs for expert parallelism +# outputs = [] +# +# # For each expert parallel rank +# for ep_rank in range(ep_size): +# # Get experts for this rank +# experts_per_rank = num_experts // ep_size +# expert_start = ep_rank * experts_per_rank +# expert_end = expert_start + experts_per_rank +# +# # Get expert weights for this rank +# w31_weight_ep = w31_weight[ +# expert_start:expert_end, : +# ] # [experts_per_rank, 2*intermediate_size, hidden_size] +# w2_weight_ep = w2_weight[ +# expert_start:expert_end, : +# ] # [experts_per_rank, hidden_size, intermediate_size] +# +# # For each tensor parallel rank +# for tp_rank in range(tp_size): +# # Create output tensor for this GPU +# out_hidden_states_local = torch.zeros_like(x) +# +# # Split w31 into w3 and w1 +# w3_weight, w1_weight = torch.chunk(w31_weight_ep, 2, dim=1) +# +# # Shard w3 and w1 separately +# w3_shard_size = intermediate_size // tp_size +# w3_start = tp_rank * w3_shard_size +# w3_end = w3_start + w3_shard_size +# w3_weight_local = w3_weight[:, w3_start:w3_end, :] +# +# w1_shard_size = intermediate_size // tp_size +# w1_start = tp_rank * w1_shard_size +# w1_end = w1_start + w1_shard_size +# w1_weight_local = w1_weight[:, w1_start:w1_end, :] +# +# # Stack the sharded weights back together +# w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1) +# +# # Shard w2 along third dimension +# w2_shard_size = intermediate_size // tp_size +# w2_start = tp_rank * w2_shard_size +# w2_end = w2_start + w2_shard_size +# w2_weight_local = w2_weight_ep[:, :, w2_start:w2_end] +# +# # Call flashinfer implementation with both parallelisms +# out_hidden_states_local = fused_moe.cutlass_fused_moe( +# x.contiguous(), +# selected_experts.to(torch.int), +# routing_weights, +# w31_weight_local.contiguous(), +# w2_weight_local.contiguous(), +# x.dtype, +# tp_size=tp_size, +# tp_rank=tp_rank, +# ep_size=ep_size, +# ep_rank=ep_rank, +# quant_scales=None, +# ) +# outputs.append(out_hidden_states_local[0]) +# +# # All-reduce to sum partial results from all GPUs +# flash_output = sum(outputs) +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) def ceil_div(a: int, b: int) -> int: @@ -933,124 +946,124 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: return x_dequant.view(original_shape) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11, 12], - reason="FP8 block scaling is only supported on 
SM100, SM110 and SM120", -) -def test_moe_fp8_block_scaling( - batch_size, hidden_size, num_experts, top_k, intermediate_size -): - """ - Test MoE with FP8 block scaling (Deepseek style): - - Activation: 128x1 blocks - - Weights: 128x128 blocks - - Each block has its own scaling factor - - Args: - batch_size: Batch size for the input - hidden_size: Hidden dimension size - num_experts: Number of experts - top_k: Number of experts to route to per token - intermediate_size: Intermediate dimension size - Only support bf16 for hidden_states - """ - torch.manual_seed(42) - otype = torch.bfloat16 - - x = torch.randn(batch_size, hidden_size, dtype=otype).cuda() - - w31_weight = ( - torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=otype).cuda() - / 10 - ) - w2_weight = ( - torch.randn(num_experts, hidden_size, intermediate_size, dtype=otype).cuda() - / 10 - ) - - # Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] - ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no quantization) - _ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - - # Quantize input and weights - x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128) - - w31_dequant = torch.empty_like(w31_weight) - w2_dequant = torch.empty_like(w2_weight) - w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn) - w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn) - w31_scales = torch.randn( - num_experts, - ceil_div(2 * intermediate_size, 128), - ceil_div(hidden_size, 128), - dtype=torch.float32, - ).cuda() - w2_scales = torch.randn( - num_experts, - ceil_div(hidden_size, 128), - ceil_div(intermediate_size, 128), - dtype=torch.float32, - ).cuda() - - for expert_id in range(num_experts): - w31, w31_s = per_block_cast_to_fp8(w31_weight[expert_id, :]) - w2, w2_s = per_block_cast_to_fp8(w2_weight[expert_id, :]) - w31_quant.data[expert_id].copy_(w31) - w31_scales.data[expert_id].copy_(w31_s) - w2_quant.data[expert_id].copy_(w2) - w2_scales.data[expert_id].copy_(w2_s) - # Dequantize for verificationa - x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape) - w31_dequant = dequantize_block( - w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape - ) - w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape) - - # Run reference implementation with dequantized tensors - _ref_output = compute_with_experts( - num_experts, - x_dequant, - w31_dequant, - w2_dequant, - selected_experts, - routing_weights, - ) - quant_scales = [ - w31_scales, # .view(-1), # W31 scales - w2_scales, # .view(-1), # W2 scales - ] - - # Call flashinfer implementation with block scaling and expect NotImplementedError - with pytest.raises( - NotImplementedError, - match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell", - ): - _ = fused_moe.cutlass_fused_moe( - x.contiguous(), - selected_experts.to(torch.int), - routing_weights, - w31_quant.contiguous(), - w2_quant.contiguous(), - otype, - tp_size=1, - tp_rank=0, - use_deepseek_fp8_block_scale=True, - quant_scales=quant_scales, - ) +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", 
TOP_K_VALUES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# @pytest.mark.skipif( +# torch.cuda.get_device_capability()[0] not in [10, 11, 12], +# reason="FP8 block scaling is only supported on SM100, SM110 and SM120", +# ) +# def test_moe_fp8_block_scaling( +# batch_size, hidden_size, num_experts, top_k, intermediate_size +# ): +# """ +# Test MoE with FP8 block scaling (Deepseek style): +# - Activation: 128x1 blocks +# - Weights: 128x128 blocks +# - Each block has its own scaling factor +# +# Args: +# batch_size: Batch size for the input +# hidden_size: Hidden dimension size +# num_experts: Number of experts +# top_k: Number of experts to route to per token +# intermediate_size: Intermediate dimension size +# Only support bf16 for hidden_states +# """ +# torch.manual_seed(42) +# otype = torch.bfloat16 +# +# x = torch.randn(batch_size, hidden_size, dtype=otype).cuda() +# +# w31_weight = ( +# torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=otype).cuda() +# / 10 +# ) +# w2_weight = ( +# torch.randn(num_experts, hidden_size, intermediate_size, dtype=otype).cuda() +# / 10 +# ) +# +# # Generate unique random expert indices for each token +# selected_experts = torch.stack( +# [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] +# ).cuda() +# +# routing_weights = torch.randn((batch_size, top_k)).cuda() +# routing_weights = F.softmax(routing_weights, dim=1) +# +# # Run reference implementation (no quantization) +# _ref_output = compute_with_experts( +# num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights +# ) +# +# # Quantize input and weights +# x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128) +# +# w31_dequant = torch.empty_like(w31_weight) +# w2_dequant = torch.empty_like(w2_weight) +# w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn) +# w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn) +# w31_scales = torch.randn( +# num_experts, +# ceil_div(2 * intermediate_size, 128), +# ceil_div(hidden_size, 128), +# dtype=torch.float32, +# ).cuda() +# w2_scales = torch.randn( +# num_experts, +# ceil_div(hidden_size, 128), +# ceil_div(intermediate_size, 128), +# dtype=torch.float32, +# ).cuda() +# +# for expert_id in range(num_experts): +# w31, w31_s = per_block_cast_to_fp8(w31_weight[expert_id, :]) +# w2, w2_s = per_block_cast_to_fp8(w2_weight[expert_id, :]) +# w31_quant.data[expert_id].copy_(w31) +# w31_scales.data[expert_id].copy_(w31_s) +# w2_quant.data[expert_id].copy_(w2) +# w2_scales.data[expert_id].copy_(w2_s) +# # Dequantize for verificationa +# x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape) +# w31_dequant = dequantize_block( +# w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape +# ) +# w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape) +# +# # Run reference implementation with dequantized tensors +# _ref_output = compute_with_experts( +# num_experts, +# x_dequant, +# w31_dequant, +# w2_dequant, +# selected_experts, +# routing_weights, +# ) +# quant_scales = [ +# w31_scales, # .view(-1), # W31 scales +# w2_scales, # .view(-1), # W2 scales +# ] +# +# # Call flashinfer implementation with block scaling and expect NotImplementedError +# with pytest.raises( +# NotImplementedError, +# match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell", +# ): +# _ = fused_moe.cutlass_fused_moe( +# x.contiguous(), +# selected_experts.to(torch.int), +# routing_weights, +# w31_quant.contiguous(), +# 
w2_quant.contiguous(), +# otype, +# tp_size=1, +# tp_rank=0, +# use_deepseek_fp8_block_scale=True, +# quant_scales=quant_scales, +# ) def quant_mxfp4_batches(a, num_experts): @@ -1083,137 +1096,137 @@ def dequant_mxfp4_batches( ) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize( - ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] -) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11, 12], - reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120", -) -def test_moe_mxfp8_mxfp4( - batch_size, - hidden_size, - num_experts, - top_k, - intermediate_size, - otype, - alpha, - beta, - limit, -): - """ - Test MoE with MXFP8 activations and MXFP4 weights. - Uses mxfp8_quantize for activations and fp4_quantize for weights. - """ - # Skip invalid configurations - if top_k > num_experts: - pytest.skip( - f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" - ) - - torch.manual_seed(42) - e = num_experts - m = batch_size - n = intermediate_size - k = hidden_size - - x = torch.randn(m, k, dtype=otype).cuda() - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 - - mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) - - mxfp4_w1, mxfp4_w1_scale = quant_mxfp4_batches(w1, e) - mxfp4_w2, mxfp4_w2_scale = quant_mxfp4_batches(w2, e) - - router_logits = torch.randn(m, e, dtype=otype).cuda() - routing_weights, selected_experts = compute_routing(router_logits, top_k) - - fake_input_scale = torch.ones(e, device=x.device) - - quant_scales = [ - mxfp4_w1_scale.view(torch.int32), - fake_input_scale, - mxfp4_w2_scale.view(torch.int32), - fake_input_scale, - ] - - flash_output = torch.zeros_like(x) - - if alpha is not None and limit is not None and beta is not None: - alpha_t = torch.ones(e, device=x.device) * alpha - limit_t = torch.ones(e, device=x.device) * limit - beta_t = torch.ones(e, device=x.device) * beta - else: - alpha_t = None - limit_t = None - beta_t = None - - # Call cutlass_fused_moe with MXFP8 activations and MXFP4 weights - _ = fused_moe.cutlass_fused_moe( - mxfp8_x, - selected_experts.to(torch.int), - routing_weights, - mxfp4_w1.contiguous().view(torch.long), - mxfp4_w2.contiguous().view(torch.long), - otype, - swiglu_alpha=alpha_t, - swiglu_limit=limit_t, - swiglu_beta=beta_t, - quant_scales=quant_scales, - input_sf=mxfp8_x_sf, - use_mxfp8_act_scaling=True, - output=flash_output, - ) - - dq_mxfp8_x = ( - mxfp8_dequantize_host( - mxfp8_x.cpu().view(torch.uint8), - mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), - True, - ) - .cuda() - .to(otype) - ) - - dq_mfxp4_w1 = ( - dequant_mxfp4_batches( - mxfp4_w1.cpu().view(torch.uint8), - mxfp4_w1_scale.cpu().view(torch.uint8).reshape(-1), - ) - .cuda() - .to(otype) - ) - - dq_mfxp4_w2 = ( - dequant_mxfp4_batches( - mxfp4_w2.cpu().view(torch.uint8), - mxfp4_w2_scale.cpu().view(torch.uint8).reshape(-1), - ) - .cuda() - .to(otype) - ) - - # Use original weights for reference computation - ref_output = compute_with_experts( - e, - dq_mxfp8_x, - dq_mfxp4_w1, - dq_mfxp4_w2, - selected_experts, - routing_weights, - alpha, - beta, - limit, - ) - - torch.testing.assert_close(ref_output, 
flash_output, rtol=1e-1, atol=1e-1) +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", TOP_K_VALUES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# @pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize( +# ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] +# ) +# @pytest.mark.skipif( +# torch.cuda.get_device_capability()[0] not in [10, 11, 12], +# reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120", +# ) +# def test_moe_mxfp8_mxfp4( +# batch_size, +# hidden_size, +# num_experts, +# top_k, +# intermediate_size, +# otype, +# alpha, +# beta, +# limit, +# ): +# """ +# Test MoE with MXFP8 activations and MXFP4 weights. +# Uses mxfp8_quantize for activations and fp4_quantize for weights. +# """ +# # Skip invalid configurations +# if top_k > num_experts: +# pytest.skip( +# f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" +# ) +# +# torch.manual_seed(42) +# e = num_experts +# m = batch_size +# n = intermediate_size +# k = hidden_size +# +# x = torch.randn(m, k, dtype=otype).cuda() +# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 +# w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 +# +# mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) +# +# mxfp4_w1, mxfp4_w1_scale = quant_mxfp4_batches(w1, e) +# mxfp4_w2, mxfp4_w2_scale = quant_mxfp4_batches(w2, e) +# +# router_logits = torch.randn(m, e, dtype=otype).cuda() +# routing_weights, selected_experts = compute_routing(router_logits, top_k) +# +# fake_input_scale = torch.ones(e, device=x.device) +# +# quant_scales = [ +# mxfp4_w1_scale.view(torch.int32), +# fake_input_scale, +# mxfp4_w2_scale.view(torch.int32), +# fake_input_scale, +# ] +# +# flash_output = torch.zeros_like(x) +# +# if alpha is not None and limit is not None and beta is not None: +# alpha_t = torch.ones(e, device=x.device) * alpha +# limit_t = torch.ones(e, device=x.device) * limit +# beta_t = torch.ones(e, device=x.device) * beta +# else: +# alpha_t = None +# limit_t = None +# beta_t = None +# +# # Call cutlass_fused_moe with MXFP8 activations and MXFP4 weights +# _ = fused_moe.cutlass_fused_moe( +# mxfp8_x, +# selected_experts.to(torch.int), +# routing_weights, +# mxfp4_w1.contiguous().view(torch.long), +# mxfp4_w2.contiguous().view(torch.long), +# otype, +# swiglu_alpha=alpha_t, +# swiglu_limit=limit_t, +# swiglu_beta=beta_t, +# quant_scales=quant_scales, +# input_sf=mxfp8_x_sf, +# use_mxfp8_act_scaling=True, +# output=flash_output, +# ) +# +# dq_mxfp8_x = ( +# mxfp8_dequantize_host( +# mxfp8_x.cpu().view(torch.uint8), +# mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), +# True, +# ) +# .cuda() +# .to(otype) +# ) +# +# dq_mfxp4_w1 = ( +# dequant_mxfp4_batches( +# mxfp4_w1.cpu().view(torch.uint8), +# mxfp4_w1_scale.cpu().view(torch.uint8).reshape(-1), +# ) +# .cuda() +# .to(otype) +# ) +# +# dq_mfxp4_w2 = ( +# dequant_mxfp4_batches( +# mxfp4_w2.cpu().view(torch.uint8), +# mxfp4_w2_scale.cpu().view(torch.uint8).reshape(-1), +# ) +# .cuda() +# .to(otype) +# ) +# +# # Use original weights for reference computation +# ref_output = compute_with_experts( +# e, +# dq_mxfp8_x, +# dq_mfxp4_w1, +# dq_mfxp4_w2, +# selected_experts, +# routing_weights, +# alpha, +# beta, +# limit, +# ) +# +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) def 
dequant_mxfp4_batches_host( @@ -1228,125 +1241,125 @@ def dequant_mxfp4_batches_host( ) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize( - ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] -) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] != 9, - reason="BF16xMXFP4 is only supported on SM90", -) -def test_moe_bf16_mxfp4( - batch_size, - hidden_size, - num_experts, - top_k, - intermediate_size, - alpha, - beta, - limit, -): - """ - Test MoE with bf16 activations and MXFP4 weights. - Uses bf16 for activations and fp4_quantize for weights. - """ - # Skip invalid configurations - if top_k > num_experts: - pytest.skip( - f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" - ) - - torch.manual_seed(42) - e = num_experts - m = batch_size - n = intermediate_size - k = hidden_size - - x = torch.randn(m, k, dtype=torch.bfloat16).cuda() - w1 = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w2 = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8) - - w1_scale = torch.randint( - 118, 123, (e, 2 * n, k // 32), device="cuda", dtype=torch.uint8 - ) - w2_scale = torch.randint( - 118, 123, (e, k, n // 32), device="cuda", dtype=torch.uint8 - ) - - router_logits = torch.randn(m, e, dtype=torch.bfloat16).cuda() - routing_weights, selected_experts = compute_routing(router_logits, top_k) - - flash_output = torch.zeros_like(x) - - if alpha is not None and limit is not None and beta is not None: - alpha_t = torch.ones(e, device=x.device) * alpha - limit_t = torch.ones(e, device=x.device) * limit - beta_t = torch.ones(e, device=x.device) * beta - else: - alpha_t = None - limit_t = None - beta_t = None - - pad_size = hidden_size - x.shape[1] - x_pad = torch.nn.functional.pad(x, (0, pad_size)) - - quant_scales = [ - w1_scale.view(torch.int32), - w2_scale.view(torch.int32), - ] - - # Call cutlass_fused_moe with BF16 activations and MXFP4 weights - _ = fused_moe.cutlass_fused_moe( - x_pad, - selected_experts.to(torch.int), - routing_weights, - w1.contiguous().view(torch.uint8), - w2.contiguous().view(torch.uint8), - torch.bfloat16, - swiglu_alpha=alpha_t, - swiglu_limit=limit_t, - swiglu_beta=beta_t, - quant_scales=quant_scales, - use_w4_group_scaling=True, - output=flash_output, - ) - - dq_mfxp4_w1 = ( - dequant_mxfp4_batches_host( - w1.cpu(), - w1_scale.cpu(), - ) - .cuda() - .to(torch.bfloat16) - ) - - dq_mfxp4_w2 = ( - dequant_mxfp4_batches_host( - w2.cpu(), - w2_scale.cpu(), - ) - .cuda() - .to(torch.bfloat16) - ) - - # Use original weights for reference computation - ref_output = compute_with_experts( - e, - x, - dq_mfxp4_w1, - dq_mfxp4_w2, - selected_experts, - routing_weights, - alpha, - beta, - limit, - ) - - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) +# @pytest.mark.parametrize("batch_size", BATCH_SIZES) +# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +# @pytest.mark.parametrize("num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("top_k", TOP_K_VALUES) +# @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +# @pytest.mark.parametrize( +# ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] +# ) +# @pytest.mark.skipif( +# 
torch.cuda.get_device_capability()[0] != 9, +# reason="BF16xMXFP4 is only supported on SM90", +# ) +# def test_moe_bf16_mxfp4( +# batch_size, +# hidden_size, +# num_experts, +# top_k, +# intermediate_size, +# alpha, +# beta, +# limit, +# ): +# """ +# Test MoE with bf16 activations and MXFP4 weights. +# Uses bf16 for activations and fp4_quantize for weights. +# """ +# # Skip invalid configurations +# if top_k > num_experts: +# pytest.skip( +# f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" +# ) +# +# torch.manual_seed(42) +# e = num_experts +# m = batch_size +# n = intermediate_size +# k = hidden_size +# +# x = torch.randn(m, k, dtype=torch.bfloat16).cuda() +# w1 = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) +# w2 = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8) +# +# w1_scale = torch.randint( +# 118, 123, (e, 2 * n, k // 32), device="cuda", dtype=torch.uint8 +# ) +# w2_scale = torch.randint( +# 118, 123, (e, k, n // 32), device="cuda", dtype=torch.uint8 +# ) +# +# router_logits = torch.randn(m, e, dtype=torch.bfloat16).cuda() +# routing_weights, selected_experts = compute_routing(router_logits, top_k) +# +# flash_output = torch.zeros_like(x) +# +# if alpha is not None and limit is not None and beta is not None: +# alpha_t = torch.ones(e, device=x.device) * alpha +# limit_t = torch.ones(e, device=x.device) * limit +# beta_t = torch.ones(e, device=x.device) * beta +# else: +# alpha_t = None +# limit_t = None +# beta_t = None +# +# pad_size = hidden_size - x.shape[1] +# x_pad = torch.nn.functional.pad(x, (0, pad_size)) +# +# quant_scales = [ +# w1_scale.view(torch.int32), +# w2_scale.view(torch.int32), +# ] +# +# # Call cutlass_fused_moe with BF16 activations and MXFP4 weights +# _ = fused_moe.cutlass_fused_moe( +# x_pad, +# selected_experts.to(torch.int), +# routing_weights, +# w1.contiguous().view(torch.uint8), +# w2.contiguous().view(torch.uint8), +# torch.bfloat16, +# swiglu_alpha=alpha_t, +# swiglu_limit=limit_t, +# swiglu_beta=beta_t, +# quant_scales=quant_scales, +# use_w4_group_scaling=True, +# output=flash_output, +# ) +# +# dq_mfxp4_w1 = ( +# dequant_mxfp4_batches_host( +# w1.cpu(), +# w1_scale.cpu(), +# ) +# .cuda() +# .to(torch.bfloat16) +# ) +# +# dq_mfxp4_w2 = ( +# dequant_mxfp4_batches_host( +# w2.cpu(), +# w2_scale.cpu(), +# ) +# .cuda() +# .to(torch.bfloat16) +# ) +# +# # Use original weights for reference computation +# ref_output = compute_with_experts( +# e, +# x, +# dq_mfxp4_w1, +# dq_mfxp4_w2, +# selected_experts, +# routing_weights, +# alpha, +# beta, +# limit, +# ) +# +# torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) if __name__ == "__main__": From 7b0f47155ae9cf90f40b2bba888096823dcedd8d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 15:42:07 +0800 Subject: [PATCH 084/103] Revert "temp rm all" This reverts commit ec89f0c76964c482755de25b58229524f910dfef. 
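For context: the hunks below re-introduce a warp-parallel variant of findTotalEltsLessThanTarget that lets the 32 lanes of a warp count how many expert offsets fall below the target, instead of each thread running its own binary search. A minimal standalone sketch of that counting idea, assuming a fixed array length of 128 and a full 32-lane warp; the helper/kernel names and example data are illustrative only, not part of this patch:

    // Standalone sketch of the warp-level "count elements < target" idea behind
    // findTotalEltsLessThanTarget_v2 (illustrative names, not the patch itself).
    #include <cstdio>
    #include <cuda_runtime.h>

    __device__ int warpCountLessThan(const long long* sorted, long long target) {
      constexpr int kLen = 128;              // assumed fixed length (offsets per node)
      constexpr unsigned kFullMask = 0xffffffffu;
      const int lane = threadIdx.x & 31;

      int count = 0;
    #pragma unroll
      for (int k = 0; k < kLen / 32; ++k) {  // each lane reads 4 strided elements
        count += (sorted[lane + k * 32] < target) ? 1 : 0;
      }
    #pragma unroll
      for (int offset = 16; offset > 0; offset >>= 1) {  // warp reduction
        count += __shfl_down_sync(kFullMask, count, offset);
      }
      // For a sorted array this equals the index found by the binary-search version.
      return __shfl_sync(kFullMask, count, 0);           // broadcast lane 0's total
    }

    __global__ void demo(const long long* sorted, long long target, int* out) {
      int total = warpCountLessThan(sorted, target);
      if (threadIdx.x == 0) *out = total;
    }

    int main() {
      long long h[128];
      for (int i = 0; i < 128; ++i) h[i] = 2 * i;        // sorted example data
      long long* d_sorted; int* d_out; int h_out = 0;
      cudaMalloc(&d_sorted, sizeof(h));
      cudaMalloc(&d_out, sizeof(int));
      cudaMemcpy(d_sorted, h, sizeof(h), cudaMemcpyHostToDevice);
      demo<<<1, 32>>>(d_sorted, 101, d_out);             // 51 values are < 101
      cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
      printf("count = %d\n", h_out);                     // prints: count = 51
      cudaFree(d_sorted);
      cudaFree(d_out);
      return 0;
    }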
--- .../cutlass_fused_moe_kernels.cuh | 134 ++++++++++++++---- 1 file changed, 110 insertions(+), 24 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d94bc69b23..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken( // ============================== Infer GEMM sizes ================================= // TODO Could linear search be better for small # experts template -__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, +__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices, int64_t const arr_length, T const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) { @@ -881,6 +882,49 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, return target_location + 1; } +template +__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { + constexpr int ARR_LENGTH_CONST = 128; + if (arr_length != ARR_LENGTH_CONST) { + asm("trap;"); + } + + constexpr unsigned full_mask = 0xffffffffu; + constexpr int WARP_SZ = 32; + const int lane_id = threadIdx.x & (WARP_SZ - 1); + + int local_count = 0; +#pragma unroll + for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { + const int idx = lane_id + k * WARP_SZ; + T v = sorted_indices[idx]; + local_count += (v < target) ? 1 : 0; + } + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(full_mask, local_count, offset); + } + int total = __shfl_sync(full_mask, local_count, 0); + + return (int64_t)total; +} + +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { +// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); + + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + +// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); +// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); +// if (out_v1 != out_v2) { +// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); +// asm("trap;"); +// } +// return out_v1; +} + template using sizeof_bits = cutlass::sizeof_bits< typename cutlass_kernels::TllmToCutlassTypeAdapter>::type>; @@ -1414,7 +1458,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. 
-constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template -__global__ void finalizeMoeRoutingKernel( +__global__ +__maxnreg__(64) +void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, - int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token, + int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { +if constexpr (not (std::is_same_v and std::is_same_v)) { + printf("finalizeMoeRoutingKernel see unsupported dtype\n"); + asm("trap;"); +} else { + constexpr int experts_per_token = 8; + if (experts_per_token != experts_per_token_real_) { asm("trap;"); } + int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; @@ -1784,43 +1841,67 @@ __global__ void finalizeMoeRoutingKernel( for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); + + int4 input_val_buf[experts_per_token]; + uint32_t enable_input_buf = 0; + +#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - if (expert_id < 0 || expert_id >= num_experts_per_node) { - continue; - } int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + if (expert_id < 0 || expert_id >= num_experts_per_node) { + continue; + } + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; } - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; - auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - ComputeElem expert_result = - arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; - expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); - } +// ComputeElem expert_result = +// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); + input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + enable_input_buf |= 1 << k_idx; + } + +#pragma unroll + for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { + if (not (enable_input_buf & (1 << k_idx))) continue; + + int64_t const k_offset = original_row * experts_per_token + k_idx; + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 
1.f : scales[k_offset]; + + int4 input_val = input_val_buf[k_idx]; + ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); +// if (bias) { +// auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; +// expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); +// } thread_output = thread_output + row_scale * expert_result; } - OutputElem output_elem = arrayConvert(thread_output); - reduced_row_ptr_v[elem_index] = output_elem; +// OutputElem output_elem = arrayConvert(thread_output); +// reduced_row_ptr_v[elem_index] = output_elem; + // TODO alignment issue? + __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); + int4 output_elem = *reinterpret_cast(&output_elem_original); + static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } +} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip @@ -2130,10 +2211,13 @@ template && @@ -2218,9 +2302,9 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto output_vec = reinterpret_cast(safe_inc_ptr(output, output_offset)); auto bias_ptr_vec = reinterpret_cast(bias_ptr + bias_offset); int64_t const start_offset = tid; - int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; + constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK; assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0); - int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; + constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; @@ -2228,6 +2312,8 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fn.alpha = gate_alpha; fn.beta = gate_beta; fn.limit = gate_limit; + +#pragma unroll for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert(gemm_result_vec[elem_index + gated_off_vec]); From 5873bb37ff7433fde6aaf7696173dc1d0d0bfdce Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:26:31 +0800 Subject: [PATCH 085/103] ARR_LENGTH_CONST --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..e355e272b2 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices return target_location + 1; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - constexpr int ARR_LENGTH_CONST = 128; if (arr_length != ARR_LENGTH_CONST) { asm("trap;"); } @@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices return (int64_t)total; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, 
target); - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); @@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template + bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void expandInputRowsKernel( InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, @@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel( if constexpr (is_nvfp4 || is_mxfp8) { static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); - int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)permuted_row + 1) - 1; @@ -2207,7 +2206,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= template + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, @@ -2270,7 +2269,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = - findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - + findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; From bb7a97a471075fabf107ddd5fa8c0330e5a895f3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:28:26 +0800 Subject: [PATCH 086/103] Revert "ARR_LENGTH_CONST" This reverts commit 5873bb37ff7433fde6aaf7696173dc1d0d0bfdce. 
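The patch reverted here and the later "hack NUM_EXPERTS_PER_NODE_CONST" commit both aim to turn the per-node expert count into a compile-time constant. A minimal sketch of that runtime-to-template dispatch pattern, assuming only 128 and 64 experts per node need to be supported; names and example data are illustrative only:

    // Sketch of dispatching a runtime expert count to a template instantiation so
    // inner loops can fully unroll (illustrative names, not the patch itself).
    #include <cstdio>
    #include <cuda_runtime.h>

    template <int kNumExperts>
    __global__ void activationKernel(const long long* expert_first_token_offset) {
      // kNumExperts is a compile-time constant here, so this loop can unroll.
      long long total = 0;
    #pragma unroll
      for (int i = 0; i < kNumExperts; ++i) total += expert_first_token_offset[i];
      if (threadIdx.x == 0 && blockIdx.x == 0) printf("sum = %lld\n", total);
    }

    void launchActivation(const long long* offsets, int num_experts_per_node) {
      // Map the runtime value onto one of the supported compile-time sizes.
      switch (num_experts_per_node) {
        case 128: activationKernel<128><<<1, 32>>>(offsets); break;
        case 64:  activationKernel<64><<<1, 32>>>(offsets);  break;
        default:  fprintf(stderr, "unsupported num_experts_per_node\n"); break;
      }
    }

    int main() {
      long long h[128];
      for (int i = 0; i < 128; ++i) h[i] = i;
      long long* d; cudaMalloc(&d, sizeof(h));
      cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);
      launchActivation(d, 128);      // prints: sum = 8128
      launchActivation(d, 64);       // prints: sum = 2016
      cudaDeviceSynchronize();       // flush device-side printf
      cudaFree(d);
      return 0;
    }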
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index e355e272b2..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -882,8 +882,9 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices return target_location + 1; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { + constexpr int ARR_LENGTH_CONST = 128; if (arr_length != ARR_LENGTH_CONST) { asm("trap;"); } @@ -909,11 +910,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices return (int64_t)total; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); @@ -1461,7 +1462,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template + bool PRE_QUANT_AWQ> __global__ void expandInputRowsKernel( InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, @@ -1556,7 +1557,7 @@ __global__ void expandInputRowsKernel( if constexpr (is_nvfp4 || is_mxfp8) { static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); - int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)permuted_row + 1) - 1; @@ -2206,7 +2207,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= template + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, @@ -2269,7 +2270,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = - findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - + findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? 
activation_params.swiglu_beta[expert] : 0.0f; From 8c719e6b25001428b96b208059381d54da803cca Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:36:03 +0800 Subject: [PATCH 087/103] hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..af9ccbf06b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -884,10 +884,10 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices template __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - constexpr int ARR_LENGTH_CONST = 128; - if (arr_length != ARR_LENGTH_CONST) { - asm("trap;"); - } + // constexpr int ARR_LENGTH_CONST = 128; + // if (arr_length != ARR_LENGTH_CONST) { + // asm("trap;"); + // } constexpr unsigned full_mask = 0xffffffffu; constexpr int WARP_SZ = 32; @@ -895,7 +895,7 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices int local_count = 0; #pragma unroll - for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { + for (int k = 0; k < arr_length / WARP_SZ; ++k) { const int idx = lane_id + k * WARP_SZ; T v = sorted_indices[idx]; local_count += (v < target) ? 1 : 0; From fece8640fa265a68ecd26657260ddb714f886a6c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:38:11 +0800 Subject: [PATCH 088/103] hack: unroll(4) --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index af9ccbf06b..c7d87ec424 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -894,7 +894,7 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices const int lane_id = threadIdx.x & (WARP_SZ - 1); int local_count = 0; -#pragma unroll +#pragma unroll(4) for (int k = 0; k < arr_length / WARP_SZ; ++k) { const int idx = lane_id + k * WARP_SZ; T v = sorted_indices[idx]; From b9bb8c7927d8621a8d9a4e9767fc7deb7fdbd896 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:40:20 +0800 Subject: [PATCH 089/103] Revert "hack: unroll(4)" This reverts commit fece8640fa265a68ecd26657260ddb714f886a6c. 
--- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index c7d87ec424..af9ccbf06b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -894,7 +894,7 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices const int lane_id = threadIdx.x & (WARP_SZ - 1); int local_count = 0; -#pragma unroll(4) +#pragma unroll for (int k = 0; k < arr_length / WARP_SZ; ++k) { const int idx = lane_id + k * WARP_SZ; T v = sorted_indices[idx]; From 7e6c8764172b303d8d2dde28d0f36ab23dcca5b6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:40:20 +0800 Subject: [PATCH 090/103] Revert "hack: findTotalEltsLessThanTarget_v2 support arbitrary arr len" This reverts commit 8c719e6b25001428b96b208059381d54da803cca. --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index af9ccbf06b..d5543fb122 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -884,10 +884,10 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices template __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - // constexpr int ARR_LENGTH_CONST = 128; - // if (arr_length != ARR_LENGTH_CONST) { - // asm("trap;"); - // } + constexpr int ARR_LENGTH_CONST = 128; + if (arr_length != ARR_LENGTH_CONST) { + asm("trap;"); + } constexpr unsigned full_mask = 0xffffffffu; constexpr int WARP_SZ = 32; @@ -895,7 +895,7 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices int local_count = 0; #pragma unroll - for (int k = 0; k < arr_length / WARP_SZ; ++k) { + for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { const int idx = lane_id + k * WARP_SZ; T v = sorted_indices[idx]; local_count += (v < target) ? 
1 : 0; From 06003aa9abb22702838fcb2f92ce2fcfbcc4fa87 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:46:37 +0800 Subject: [PATCH 091/103] hack NUM_EXPERTS_PER_NODE_CONST --- .../cutlass_fused_moe_kernels.cuh | 112 ++++++++++++------ 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d5543fb122..38239c08c1 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices return target_location + 1; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - constexpr int ARR_LENGTH_CONST = 128; if (arr_length != ARR_LENGTH_CONST) { asm("trap;"); } @@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices return (int64_t)total; } -template +template __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); @@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template + bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void expandInputRowsKernel( InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, @@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel( if constexpr (is_nvfp4 || is_mxfp8) { static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); - int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)permuted_row + 1) - 1; @@ -1735,9 +1734,20 @@ void expandInputRowsKernelLauncher( TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4"); TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); - return &expandInputRowsKernel; + false, NUM_EXPERTS_PER_NODE_CONST>; + } + if (num_experts_per_node == 64) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; + return &expandInputRowsKernel; + } + printf("unsupported num_experts_per_node\n"); + exit(1); } else #endif { @@ -2159,7 +2169,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, float gate_bias = 0.0f; float gate_limit = std::numeric_limits::infinity(); if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) { - int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node, (int64_t)token + 1) - 1; gate_alpha = activation_type.swiglu_alpha ? 
activation_type.swiglu_alpha[expert] : 1.0f; @@ -2207,7 +2217,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= template + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, @@ -2270,7 +2280,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = - findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - + findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; @@ -2444,30 +2454,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 auto fn = [&]() { auto fn = [&](auto block_scaling_type) { - auto fn_list = std::array{ - &doActivationKernel, - decltype(block_scaling_type)::value>, // Gelu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Relu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Silu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Swiglu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Geglu - &doActivationKernel, // SwigluBias - &doActivationKernel, - decltype(block_scaling_type)::value> // Identity - - }; - return fn_list[static_cast(activation_type.activation_type)]; + if (num_experts_per_node == 128) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; + auto fn_list = std::array{ + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu + &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity + + }; + return fn_list[static_cast(activation_type.activation_type)]; + } + if (num_experts_per_node == 64) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; + auto fn_list = std::array{ + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu + &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity + + }; + return fn_list[static_cast(activation_type.activation_type)]; + } + printf("unsupported 
num_experts_per_node\n"); + exit(1); }; auto NVFP4 = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, From afa5a617fd30d8a169b2dcced84bbbd1cde94745 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:51:32 +0800 Subject: [PATCH 092/103] temp: 4gpu bench --- benchmarks/bench_cutlass_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 850ec118ba..0257e2a55e 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -29,7 +29,7 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -num_ranks = 2 +num_ranks = 4 test_configs = [ # { From 1583eb018d9b70c34a6520d73c431e26fcdd3d14 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 16:53:51 +0800 Subject: [PATCH 093/103] fix --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 38239c08c1..937987d141 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2482,7 +2482,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 return fn_list[static_cast(activation_type.activation_type)]; } if (num_experts_per_node == 64) { - constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; + constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; auto fn_list = std::array{ &doActivationKernel, From 3704820e19f6ba5fa00da1af9962be9a43e07bde Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 20:17:15 +0800 Subject: [PATCH 094/103] temp rm all --- .../cutlass_fused_moe_kernels.cuh | 238 ++++-------------- 1 file changed, 55 insertions(+), 183 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 937987d141..d94bc69b23 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -866,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken( // ============================== Infer GEMM sizes ================================= // TODO Could linear search be better for small # experts template -__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices, +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) { @@ -882,48 +881,6 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices return target_location + 1; } -template -__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { - if (arr_length != ARR_LENGTH_CONST) { - asm("trap;"); - } - - constexpr unsigned full_mask = 0xffffffffu; - constexpr int WARP_SZ = 32; - const int lane_id = threadIdx.x & (WARP_SZ - 1); - - int local_count = 0; -#pragma unroll - for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { - const int idx = lane_id + k * WARP_SZ; - T v = sorted_indices[idx]; 
- local_count += (v < target) ? 1 : 0; - } - -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - local_count += __shfl_down_sync(full_mask, local_count, offset); - } - int total = __shfl_sync(full_mask, local_count, 0); - - return (int64_t)total; -} - -template -__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { -// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); - - return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); - -// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); -// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); -// if (out_v1 != out_v2) { -// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); -// asm("trap;"); -// } -// return out_v1; -} - template using sizeof_bits = cutlass::sizeof_bits< typename cutlass_kernels::TllmToCutlassTypeAdapter>::type>; @@ -1457,23 +1414,20 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template + bool PRE_QUANT_AWQ> __global__ void expandInputRowsKernel( InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_, + int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { - constexpr int hidden_size = 7168; - if (hidden_size != hidden_size_real_) { asm("trap;"); } - static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); @@ -1549,14 +1503,14 @@ __global__ void expandInputRowsKernel( permuted_row * hidden_size / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; - constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; - constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD; + int64_t const stride = EXPAND_THREADS_PER_BLOCK; + int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); if constexpr (is_nvfp4 || is_mxfp8) { static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); - int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)permuted_row + 1) - 1; @@ -1565,7 +1519,6 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? 
fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; -#pragma unroll for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; if constexpr (need_nvfp4_quant || need_mxfp8_quant) { @@ -1697,7 +1650,7 @@ void expandInputRowsKernelLauncher( static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). - int64_t const blocks = std::min(smCount * 16, std::max(num_rows * k, num_padding_tokens)); + int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens)); int64_t const threads = EXPAND_THREADS_PER_BLOCK; auto func = [&]() { @@ -1734,20 +1687,9 @@ void expandInputRowsKernelLauncher( TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4"); TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); - if (num_experts_per_node == 128) { - constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; - return &expandInputRowsKernel; - } - if (num_experts_per_node == 64) { - constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; - return &expandInputRowsKernel; - } - printf("unsupported num_experts_per_node\n"); - exit(1); + false>; } else #endif { @@ -1806,20 +1748,11 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. template -__global__ -__maxnreg__(64) -void finalizeMoeRoutingKernel( +__global__ void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, - int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, + int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { -if constexpr (not (std::is_same_v and std::is_same_v)) { - printf("finalizeMoeRoutingKernel see unsupported dtype\n"); - asm("trap;"); -} else { - constexpr int experts_per_token = 8; - if (experts_per_token != experts_per_token_real_) { asm("trap;"); } - int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; @@ -1851,67 +1784,43 @@ if constexpr (not (std::is_same_v and std::is_sam for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); - - int4 input_val_buf[experts_per_token]; - uint32_t enable_input_buf = 0; - -#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - - int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { 
continue; } + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; -// ComputeElem expert_result = -// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); - input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); - enable_input_buf |= 1 << k_idx; - } - -#pragma unroll - for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { - if (not (enable_input_buf & (1 << k_idx))) continue; - - int64_t const k_offset = original_row * experts_per_token + k_idx; - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; - - int4 input_val = input_val_buf[k_idx]; - ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); -// if (bias) { -// auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; -// expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); -// } + ComputeElem expert_result = + arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + if (bias) { + auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; + expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); + } thread_output = thread_output + row_scale * expert_result; } -// OutputElem output_elem = arrayConvert(thread_output); -// reduced_row_ptr_v[elem_index] = output_elem; - // TODO alignment issue? - __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); - int4 output_elem = *reinterpret_cast(&output_elem_original); - static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); - *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; + OutputElem output_elem = arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip @@ -2169,7 +2078,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, float gate_bias = 0.0f; float gate_limit = std::numeric_limits::infinity(); if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) { - int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node, + int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)token + 1) - 1; gate_alpha = activation_type.swiglu_alpha ? 
activation_type.swiglu_alpha[expert] : 1.0f; @@ -2217,17 +2126,14 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= template + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, - int num_experts_per_node, int64_t inter_size_real_, + int num_experts_per_node, int64_t inter_size, float const* fc2_act_global_scale, bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params) { - constexpr int inter_size = 2048; - if (inter_size != inter_size_real_) { asm("trap;"); } - #ifdef ENABLE_FP4 constexpr bool IsNVFP4 = std::is_same_v && @@ -2280,7 +2186,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = - findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - + findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; @@ -2312,9 +2218,9 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto output_vec = reinterpret_cast(safe_inc_ptr(output, output_offset)); auto bias_ptr_vec = reinterpret_cast(bias_ptr + bias_offset); int64_t const start_offset = tid; - constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK; + int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0); - constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; + int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; @@ -2322,8 +2228,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fn.alpha = gate_alpha; fn.beta = gate_beta; fn.limit = gate_limit; - -#pragma unroll for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert(gemm_result_vec[elem_index + gated_off_vec]); @@ -2454,62 +2358,30 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 auto fn = [&]() { auto fn = [&](auto block_scaling_type) { - if (num_experts_per_node == 128) { - constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; - auto fn_list = std::array{ - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu - &doActivationKernel, // SwigluBias - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity - - }; - return 
fn_list[static_cast(activation_type.activation_type)]; - } - if (num_experts_per_node == 64) { - constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; - auto fn_list = std::array{ - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu - &doActivationKernel, // SwigluBias - &doActivationKernel, - decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity - - }; - return fn_list[static_cast(activation_type.activation_type)]; - } - printf("unsupported num_experts_per_node\n"); - exit(1); + auto fn_list = std::array{ + &doActivationKernel, + decltype(block_scaling_type)::value>, // Gelu + &doActivationKernel, + decltype(block_scaling_type)::value>, // Relu + &doActivationKernel, + decltype(block_scaling_type)::value>, // Silu + &doActivationKernel, + decltype(block_scaling_type)::value>, // Swiglu + &doActivationKernel, + decltype(block_scaling_type)::value>, // Geglu + &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value> // Identity + + }; + return fn_list[static_cast(activation_type.activation_type)]; }; auto NVFP4 = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, From 3a74536ee71a20472a2d99b99875d098891aef18 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 20:18:36 +0800 Subject: [PATCH 095/103] partial cp --- .../cutlass_fused_moe_kernels.cuh | 180 ++++++++++++++---- 1 file changed, 140 insertions(+), 40 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index d94bc69b23..76edcfdd0b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/workspace.h" @@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken( // ============================== Infer GEMM sizes ================================= // TODO Could linear search be better for small # experts template -__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, +__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices, int64_t const arr_length, T const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) { @@ -881,6 +882,48 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, return target_location + 1; } +template +__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) { + if (arr_length != ARR_LENGTH_CONST) { + asm("trap;"); + } + + constexpr unsigned full_mask = 0xffffffffu; + constexpr int WARP_SZ = 32; + const int lane_id = threadIdx.x & (WARP_SZ - 1); + + int local_count = 0; +#pragma unroll + for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) { + const int idx = lane_id + k * WARP_SZ; + T v = sorted_indices[idx]; + local_count += (v < target) ? 
1 : 0; + } + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(full_mask, local_count, offset); + } + int total = __shfl_sync(full_mask, local_count, 0); + + return (int64_t)total; +} + +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) { +// return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); + + return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); + +// int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target); +// int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target); +// if (out_v1 != out_v2) { +// printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2); +// asm("trap;"); +// } +// return out_v1; +} + template using sizeof_bits = cutlass::sizeof_bits< typename cutlass_kernels::TllmToCutlassTypeAdapter>::type>; @@ -1418,16 +1461,19 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template + bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void expandInputRowsKernel( InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, + int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { + constexpr int hidden_size = 7168; + if (hidden_size != hidden_size_real_) { asm("trap;"); } + static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); @@ -1503,14 +1549,14 @@ __global__ void expandInputRowsKernel( permuted_row * hidden_size / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; - int64_t const stride = EXPAND_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; + constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK; + constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD; assert(hidden_size % ELEM_PER_THREAD == 0); assert(hidden_size % VecSize == 0); if constexpr (is_nvfp4 || is_mxfp8) { static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); - int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t)permuted_row + 1) - 1; @@ -1519,6 +1565,7 @@ __global__ void expandInputRowsKernel( float global_scale_val = fc1_act_global_scale ? 
fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; +#pragma unroll for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; if constexpr (need_nvfp4_quant || need_mxfp8_quant) { @@ -1687,9 +1734,20 @@ void expandInputRowsKernelLauncher( TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4"); TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); - return &expandInputRowsKernel; + false, NUM_EXPERTS_PER_NODE_CONST>; + } + if (num_experts_per_node == 64) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; + return &expandInputRowsKernel; + } + printf("unsupported num_experts_per_node\n"); + exit(1); } else #endif { @@ -1748,11 +1806,16 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. template -__global__ void finalizeMoeRoutingKernel( +__global__ +__maxnreg__(64) +void finalizeMoeRoutingKernel( GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, - int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token, + int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { + constexpr int experts_per_token = 8; + if (experts_per_token != experts_per_token_real_) { asm("trap;"); } + int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; @@ -2078,7 +2141,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, float gate_bias = 0.0f; float gate_limit = std::numeric_limits::infinity(); if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) { - int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, + int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node, (int64_t)token + 1) - 1; gate_alpha = activation_type.swiglu_alpha ? 
activation_type.swiglu_alpha[expert] : 1.0f; @@ -2126,14 +2189,17 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= template + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, - int num_experts_per_node, int64_t inter_size, + int num_experts_per_node, int64_t inter_size_real_, float const* fc2_act_global_scale, bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params) { + constexpr int inter_size = 2048; + if (inter_size != inter_size_real_) { asm("trap;"); } + #ifdef ENABLE_FP4 constexpr bool IsNVFP4 = std::is_same_v && @@ -2186,7 +2252,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, activation_params.swiglu_limit) { // TODO this is almost certainly faster as a linear scan expert = - findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - + findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f; gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f; @@ -2218,9 +2284,9 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto output_vec = reinterpret_cast(safe_inc_ptr(output, output_offset)); auto bias_ptr_vec = reinterpret_cast(bias_ptr + bias_offset); int64_t const start_offset = tid; - int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; + constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK; assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0); - int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; + constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD; assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0); int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD; @@ -2228,6 +2294,8 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, fn.alpha = gate_alpha; fn.beta = gate_beta; fn.limit = gate_limit; + +#pragma unroll for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto fc1_value = arrayConvert(gemm_result_vec[elem_index + gated_off_vec]); @@ -2358,30 +2426,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 auto fn = [&]() { auto fn = [&](auto block_scaling_type) { - auto fn_list = std::array{ - &doActivationKernel, - decltype(block_scaling_type)::value>, // Gelu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Relu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Silu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Swiglu - &doActivationKernel, - decltype(block_scaling_type)::value>, // Geglu - &doActivationKernel, // SwigluBias - &doActivationKernel, - decltype(block_scaling_type)::value> // Identity - - }; - return fn_list[static_cast(activation_type.activation_type)]; + if (num_experts_per_node == 128) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 128; + auto fn_list = std::array{ + &doActivationKernel, + decltype(block_scaling_type)::value, 
NUM_EXPERTS_PER_NODE_CONST>, // Gelu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu + &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity + + }; + return fn_list[static_cast(activation_type.activation_type)]; + } + if (num_experts_per_node == 64) { + constexpr int NUM_EXPERTS_PER_NODE_CONST = 64; + auto fn_list = std::array{ + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Gelu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Relu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Silu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Swiglu + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>, // Geglu + &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST> // Identity + + }; + return fn_list[static_cast(activation_type.activation_type)]; + } + printf("unsupported num_experts_per_node\n"); + exit(1); }; auto NVFP4 = tensorrt_llm::common::ConstExprWrapper< TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType, From 348a53665576e8b769064b7251e2ad1a4de2bb91 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 21:06:03 +0800 Subject: [PATCH 096/103] cp change-block-thread, pragma-unroll, mv-if-check --- .../cutlass_fused_moe_kernels.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 76edcfdd0b..33a981488a 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. 
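// For example, with rows_in_input = 4 and k = 2 the expanded indices run 0..7: expanded rows 0 and 4
// both read source row 0, rows 1 and 5 read source row 1, and so on, i.e. source_row = expanded_row %
// rows_in_input (the example values here are illustrative only).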
-constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template and std::is_same_v)) { + printf("finalizeMoeRoutingKernel see unsupported dtype\n"); + asm("trap;"); +} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1847,16 +1851,19 @@ void finalizeMoeRoutingKernel( for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); + +#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - if (expert_id < 0 || expert_id >= num_experts_per_node) { - continue; - } int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + if (expert_id < 0 || expert_id >= num_experts_per_node) { + continue; + } + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; @@ -1884,6 +1891,7 @@ void finalizeMoeRoutingKernel( asm volatile("griddepcontrol.launch_dependents;"); #endif } +} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From 0e62fa89f3cfa1b96adbc0c077f30591a072ef66 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 22:24:14 +0800 Subject: [PATCH 097/103] Revert "cp change-block-thread, pragma-unroll, mv-if-check" This reverts commit 348a53665576e8b769064b7251e2ad1a4de2bb91. --- .../cutlass_fused_moe_kernels.cuh | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 33a981488a..76edcfdd0b 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. 
-constexpr static int EXPAND_THREADS_PER_BLOCK = 128; +constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template and std::is_same_v)) { - printf("finalizeMoeRoutingKernel see unsupported dtype\n"); - asm("trap;"); -} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1851,19 +1847,16 @@ if constexpr (not (std::is_same_v and std::is_sam for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); - -#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - - int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; @@ -1891,7 +1884,6 @@ if constexpr (not (std::is_same_v and std::is_sam asm volatile("griddepcontrol.launch_dependents;"); #endif } -} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From 941b68ca99906b939623f5083afa710c20961afe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Fri, 19 Sep 2025 22:24:58 +0800 Subject: [PATCH 098/103] enable all except for 21:06 --- .../cutlass_fused_moe_kernels.cuh | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 76edcfdd0b..ad960d4190 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1813,6 +1813,10 @@ void finalizeMoeRoutingKernel( ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_, int const num_experts_per_node, int const start_expert_id) { +if constexpr (not (std::is_same_v and std::is_same_v)) { + printf("finalizeMoeRoutingKernel see unsupported dtype\n"); + asm("trap;"); +} else { constexpr int experts_per_token = 8; if (experts_per_token != experts_per_token_real_) { asm("trap;"); } @@ -1847,6 +1851,11 @@ void finalizeMoeRoutingKernel( for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); + + int4 input_val_buf[experts_per_token]; + uint32_t enable_input_buf = 0; + +#pragma unroll for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; @@ -1862,28 +1871,46 @@ void finalizeMoeRoutingKernel( continue; } - float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 
1.f : scales[k_offset]; - auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - ComputeElem expert_result = - arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - if (bias) { - auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; - expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); - } +// ComputeElem expert_result = +// arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); + static_assert(sizeof(expanded_permuted_rows_row_ptr[0]) == sizeof(int4)); + input_val_buf[k_idx] = *reinterpret_cast(expanded_permuted_rows_row_ptr + elem_index); + enable_input_buf |= 1 << k_idx; + } + +#pragma unroll + for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { + if (not (enable_input_buf & (1 << k_idx))) continue; + + int64_t const k_offset = original_row * experts_per_token + k_idx; + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + + int4 input_val = input_val_buf[k_idx]; + ComputeElem expert_result = arrayConvert(*reinterpret_cast(&input_val)); +// if (bias) { +// auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; +// expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); +// } thread_output = thread_output + row_scale * expert_result; } - OutputElem output_elem = arrayConvert(thread_output); - reduced_row_ptr_v[elem_index] = output_elem; +// OutputElem output_elem = arrayConvert(thread_output); +// reduced_row_ptr_v[elem_index] = output_elem; + // TODO alignment issue? + __align__(16) OutputElem output_elem_original = arrayConvert(thread_output); + int4 output_elem = *reinterpret_cast(&output_elem_original); + static_assert(sizeof(reduced_row_ptr_v[0]) == sizeof(int4)); + *reinterpret_cast(reduced_row_ptr_v + elem_index) = output_elem; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } +} // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip From d504c6181a04b2afac69bd9c0bb51b2a5856b161 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 20 Sep 2025 08:41:06 +0800 Subject: [PATCH 099/103] Revert "feat: Benchmark mm_fp4 mxfp4 support and gemm autotune support. Restore mm_fp4 API behavior (#1706)" This reverts commit e8f5460f6f798efeadb769e3ad5e7e581629b03a. 
--- .../routines/flashinfer_benchmark_utils.py | 1 - benchmarks/routines/gemm.py | 133 +++--------------- flashinfer/gemm.py | 2 +- 3 files changed, 19 insertions(+), 117 deletions(-) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 9a81a64439..8b154ae527 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -39,7 +39,6 @@ "out_dtype", "mma_sm", "use_128x4_sf_layout", - "use_nvfp4", ], "moe": [ "num_tokens", diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 9190e7c30f..b883d7d079 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -6,7 +6,6 @@ from einops import einsum import flashinfer -from flashinfer.autotuner import autotune from flashinfer.testing.utils import ( bench_gpu_time, dequantize_fp8, @@ -138,17 +137,6 @@ def parse_gemm_args(line, parser): action="store_true", help="Use 128x4 SF layout for the input and mat2.", ) - parser.add_argument( - "--use_nvfp4", - action="store_true", - help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.", - ) - parser.add_argument( - "--autotune", - action="store_true", - default=False, - help=("Enable autotuner warmup for supported routines (mm_fp4 and bmm_fp8)."), - ) args = parser.parse_args(line) if args.verbose >= 1: @@ -565,9 +553,6 @@ def testBmmFp8(args): backends = args.backends is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - autotune_supported_backends = [ - "cutlass", - ] input_dtype = dtype_str_to_torch_dtype(args.input_dtype) if input_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: @@ -588,19 +573,6 @@ def testBmmFp8(args): ) ## Done parsing input arguments - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) - - if len(backends) == 0: - print("[ERROR] No backends to test. 
Exiting.") - return - ## Prepare input tensors input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16) input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) @@ -638,18 +610,6 @@ def run_backend(backend): reference_output = torch.bmm(input, mat2) has_reference_output = True - if getattr(args, "autotune", False): - warmup_iters = ( - args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 - ) - for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for bmm_fp8: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} @@ -695,14 +655,6 @@ def run_backend(backend): res = [] for backend in backends: - backend_name = backend + ( - "_autotune" - if ( - getattr(args, "autotune", False) - and backend in autotune_supported_backends - ) - else "" - ) if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) @@ -714,7 +666,7 @@ def run_backend(backend): ) tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec - print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec) + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) if args.output_path is not None: cur_res = defaultdict(str) @@ -730,7 +682,7 @@ def run_backend(backend): cur_res["input_dtype"] = input_dtype cur_res["mat2_dtype"] = mat2_dtype cur_res["out_dtype"] = res_dtype - cur_res["backend"] = backend_name + cur_res["backend"] = backend cur_res["case_tag"] = args.case_tag res.append(cur_res) return res @@ -773,8 +725,6 @@ def testMmFp4(args): is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout - use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cutlass", "trtllm"] res_dtype = dtype_str_to_torch_dtype(args.out_dtype) if res_dtype not in [torch.bfloat16, torch.float16]: @@ -786,42 +736,24 @@ def testMmFp4(args): if "trtllm" in backends: remove_trtllm = False if res_dtype == torch.float16: - print("[INFO] trtllm backend does not support float16 output") + print("[INFO] trtllm backend does not suppot float16 output") remove_trtllm = True if remove_trtllm: backends.remove("trtllm") - if not use_nvfp4: - print( - "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("trtllm") if "cutlass" in backends: remove_cutlass = False if not use_128x4_sf_layout: - print("[INFO] cutlass backend does not support use_128x4_sf_layout=False") + print("[INFO] cutlass backend does not suppot use_128x4_sf_layout=False") remove_cutlass = True - if not use_nvfp4: - print( - "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)" - ) - backends.remove("cutlass") if remove_cutlass: backends.remove("cutlass") if "cudnn" in backends: remove_cudnn = False if not use_128x4_sf_layout: - print("[INFO] cudnn backend does not support use_128x4_sf_layout=False") + print("[INFO] cudnn backend does not suppot use_128x4_sf_layout=False") remove_cudnn = True if remove_cudnn: backends.remove("cudnn") - if getattr(args, "autotune", False): - backends_to_remove = [] - for cur_backend in backends: - if cur_backend not in autotune_supported_backends: - print(f"[INFO] {cur_backend} 
backend does not support autotune") - backends_to_remove.append(cur_backend) - for cur_backend in backends_to_remove: - backends.remove(cur_backend) if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") @@ -838,20 +770,15 @@ def testMmFp4(args): global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max() global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max() - if use_nvfp4: - input_fp4, input_inv_s = flashinfer.nvfp4_quantize( - input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False - ) - mat2_fp4, mat2_inv_s = flashinfer.nvfp4_quantize( - mat2, - global_sf_mat2, - sfLayout=flashinfer.SfLayout.layout_128x4, - do_shuffle=False, - ) - else: # mxfp4 - input_fp4, input_inv_s = flashinfer.mxfp4_quantize(input) - mat2_fp4, mat2_inv_s = flashinfer.mxfp4_quantize(mat2) - + input_fp4, input_inv_s = flashinfer.nvfp4_quantize( + input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False + ) + mat2_fp4, mat2_inv_s = flashinfer.nvfp4_quantize( + mat2, + global_sf_mat2, + sfLayout=flashinfer.SfLayout.layout_128x4, + do_shuffle=False, + ) if "trtllm" in backends: mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize( mat2, @@ -866,7 +793,7 @@ def testMmFp4(args): print(f"[VVERBOSE] {mat2_fp4.shape = }") print(f"[VVERBOSE] {mat2_fp4.dtype = }") - alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None + alpha = 1.0 / (global_sf_input * global_sf_mat2) # res = torch.empty([m, n], device="cuda", dtype=res_dtype) def run_backend(backend): @@ -878,12 +805,9 @@ def run_backend(backend): b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, alpha=alpha, out_dtype=res_dtype, - block_size=16 - if use_nvfp4 - else 32, # nvfp4 only supports 16; mxfp4 only supports 32. + block_size=16, # Only supports 16 use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, - use_nvfp4=use_nvfp4, ) else: raise ValueError(f"Unsupported backend: {backend}") @@ -893,18 +817,6 @@ def run_backend(backend): reference_output = torch.mm(input, mat2.T) has_reference_output = True - if getattr(args, "autotune", False): - warmup_iters = ( - args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 - ) - for cur_backend in backends: - if cur_backend in autotune_supported_backends: - if args.verbose >= 1: - print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters") - with autotune(True): - for _ in range(warmup_iters): - run_backend(cur_backend) - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} @@ -944,14 +856,6 @@ def run_backend(backend): res = [] for backend in backends: - backend_name = backend + ( - "_autotune" - if ( - getattr(args, "autotune", False) - and backend in autotune_supported_backends - ) - else "" - ) if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) @@ -961,7 +865,7 @@ def run_backend(backend): ) # 0.5 for fp4 tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec - print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec) + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) if args.output_path is not None: cur_res = defaultdict(str) @@ -975,8 +879,7 @@ def run_backend(backend): cur_res["k"] = k cur_res["out_dtype"] = res_dtype cur_res["use_128x4_sf_layout"] = use_128x4_sf_layout - cur_res["backend"] = backend_name - cur_res["use_nvfp4"] = use_nvfp4 + 
cur_res["backend"] = backend cur_res["case_tag"] = args.case_tag res.append(cur_res) return res diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 98424a1bb0..4dfa970860 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -2005,7 +2005,7 @@ def mm_fp4( block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", - use_nvfp4: bool = True, + use_nvfp4: bool = False, ) -> torch.Tensor: r"""MM FP4 From 822ae9b52afb1b6f52b1216ed054f3a0d2a8a2b5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 20 Sep 2025 08:41:29 +0800 Subject: [PATCH 100/103] Revert "Enabled alpha with the mx_fp4 format (#1688)" This reverts commit bc4239316e2a27514dfa2d89a21faa365efbac4c. # Conflicts: # tests/test_mm_fp4.py --- flashinfer/gemm.py | 29 ++++++++++------------------- tests/test_mm_fp4.py | 22 +++++++--------------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 4dfa970860..ee8861b54e 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1652,7 +1652,6 @@ def build_cudnn_gemm_block_scale_dequantize_graph( o_type, block_size, device, - alpha, use_nvfp4, ): _check_cudnn_availability() @@ -1705,7 +1704,8 @@ def build_cudnn_gemm_block_scale_dequantize_graph( c_final_cudnn_tensor = c_tensor - if alpha is not None: + # if use_nvfp4 is True, we need to multiply the output by the global scale + if use_nvfp4: global_scale_cudnn_tensor = graph.tensor( name="global_scale", dim=(1, 1, 1), @@ -1734,7 +1734,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph( # WAR: The alpha (contains the global scale) is not supported by the cuBLAS backend (eng0) # in older cuDNN versions, so we deselect it. - if (alpha is not None) and (not _is_cublas_fp4_available_in_cudnn()): + if use_nvfp4 and not _is_cublas_fp4_available_in_cudnn(): graph.deselect_engines(["eng0"]) graph.check_support() graph.build_plans() @@ -1743,14 +1743,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph( def execute_cudnn_gemm_fp4_graph( - graph, - a, - b, - a_descale, - b_descale, - alpha, - c_final, - workspace_buffer, + graph, a, b, a_descale, b_descale, alpha, c_final, workspace_buffer, use_nvfp4 ): variant_pack = { UIDs.A_UID.value: a.view(_get_native_fp4_dtype()), @@ -1760,7 +1753,7 @@ def execute_cudnn_gemm_fp4_graph( UIDs.O_UID.value: c_final, } - if alpha is not None: + if use_nvfp4: variant_pack[UIDs.ALPHA_UID.value] = alpha.view(torch.float) if workspace_buffer.numel() < graph.get_workspace_size(): @@ -2005,7 +1998,6 @@ def mm_fp4( block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", - use_nvfp4: bool = False, ) -> torch.Tensor: r"""MM FP4 @@ -2024,7 +2016,7 @@ def mm_fp4( Block scale tensor for B, shape (k, n // block_size), float8_e4m3fn or uint8. alpha: Optional[torch.Tensor] - Global scale tensor, float scalar. + Global scale tensor, float scalar in case of nvfp4 quantization. None in case of mxfp4 quantization. out_dtype: torch.dtype Output dtype, bf16 or fp16. @@ -2041,9 +2033,6 @@ def mm_fp4( backend: Literal["cudnn", "trtllm", "cutlass"] Backend to use, defaults to "cudnn". - use_nvfp4: bool - Whether to use nvfp4 quantization or mxfp4 quantization, defaults to False. - Notes ----- When cudnn/cutlass backend is used, both a and b should quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False. 
@@ -2068,6 +2057,9 @@ def mm_fp4( >>> out.shape torch.Size([48, 256]) """ + # nvfp4 quantization if alpha provided, mxfp4 quantization if no alpha provided + use_nvfp4 = alpha is not None + # pre-check the input tensor, block scale tensor and alpha tensor if a.ndim != 2 or b.ndim != 2: raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") @@ -2155,13 +2147,12 @@ def mm_fp4( _torch_data_type_to_cudnn_data_type(out_dtype), block_size, a.device, - alpha, use_nvfp4, ) # execute the fp4 cudnn graph execute_cudnn_gemm_fp4_graph( - graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer + graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, use_nvfp4 ) elif backend == "trtllm": if out_dtype != torch.bfloat16: diff --git a/tests/test_mm_fp4.py b/tests/test_mm_fp4.py index 6e8514be73..2c0d01f3dd 100644 --- a/tests/test_mm_fp4.py +++ b/tests/test_mm_fp4.py @@ -8,7 +8,6 @@ nvfp4_quantize, mxfp4_quantize, ) -from flashinfer.utils import get_compute_capability # TODO: Consdier splitting this function up for the various backends @@ -19,23 +18,17 @@ @pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) @pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) @pytest.mark.parametrize("auto_tuning", [False, True]) -@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"]) +@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4"]) def test_mm_fp4( m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type ): - use_nvfp4 = fp4_type == "nvfp4" - - if backend == "trtllm": - if res_dtype == torch.float16: - pytest.skip("Skipping test for trtllm fp4 with float16") - compute_capability = get_compute_capability(torch.device(device="cuda")) - if compute_capability[0] in [11, 12]: - pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") + if backend == "trtllm" and res_dtype == torch.float16: + pytest.skip("Skipping test for trtllm fp4 with float16") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") if auto_tuning and backend == "cudnn": pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True") - if not use_nvfp4 and backend != "cudnn": + if fp4_type == "mxfp4" and backend != "cudnn": pytest.skip("mx_fp4 is only supported for cudnn backend") input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) @@ -48,8 +41,9 @@ def test_mm_fp4( # for trtllm, we need to shuffle mat2 because we swap A, B. 
do_shuffle_b = backend == "trtllm" + use_nvfp4 = fp4_type == "nvfp4" block_size = 16 if use_nvfp4 else 32 - has_alpha = fp4_type == "mxfp4_alpha" or fp4_type == "nvfp4" + alpha = None # None in case of mxfp4 if use_nvfp4: input_fp4, input_inv_s = nvfp4_quantize( @@ -61,12 +55,11 @@ def test_mm_fp4( sfLayout=SfLayout.layout_128x4, do_shuffle=do_shuffle_b, ) + alpha = 1.0 / (global_sf_input * global_sf_mat2) else: input_fp4, input_inv_s = mxfp4_quantize(input) mat2_fp4, mat2_inv_s = mxfp4_quantize(mat2) - alpha = 1.0 / (global_sf_input * global_sf_mat2) if has_alpha else None - reference = torch.mm(input, mat2.T) res = torch.empty([m, n], device="cuda", dtype=res_dtype) @@ -83,7 +76,6 @@ def test_mm_fp4( block_size=block_size, use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, - use_nvfp4=use_nvfp4, ) cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) From 299caf51e53a6d245d960de2949fe8963f1fac54 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 20 Sep 2025 11:09:59 +0800 Subject: [PATCH 101/103] enable change-thread-block --- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9bc2e8c36d..1377cdb255 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) { // (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the // source matrix, we simply take the modulus of the expanded index. -constexpr static int EXPAND_THREADS_PER_BLOCK = 256; +constexpr static int EXPAND_THREADS_PER_BLOCK = 128; template Date: Sat, 20 Sep 2025 12:21:28 +0800 Subject: [PATCH 102/103] enable mv-unpermuted_row_to_permuted_row --- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 1377cdb255..32a2d33146 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1859,13 +1859,14 @@ if constexpr (not (std::is_same_v and std::is_sam for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - if (expert_id < 0 || expert_id >= num_experts_per_node) { - continue; - } int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + if (expert_id < 0 || expert_id >= num_experts_per_node) { + continue; + } + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue; From 68b8e6db730e81c9ac5ae3ed55b4b01dde5df0a5 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Sat, 20 Sep 2025 14:21:29 +0800 Subject: [PATCH 103/103] Revert "enable mv-unpermuted_row_to_permuted_row" This reverts commit d83a3cb1460aa68cff6dbaad8a29bab17bccf3cf. 
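The reverted commit hoisted the unpermuted_row_to_permuted_row gather in
finalizeMoeRoutingKernel above the expert_id bounds check, presumably so the dependent
global load is issued before the branch and its latency can overlap the check:

    // ordering introduced by d83a3cb1 (removed again by this revert)
    int64_t const expanded_original_row = original_row + k_idx * num_rows;
    int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
    if (expert_id < 0 || expert_id >= num_experts_per_node) {
      continue;
    }

The diff below restores the original ordering, in which rows routed to an expert outside
this node's range skip the gather entirely.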
--- .../cutlass_backend/cutlass_fused_moe_kernels.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 32a2d33146..1377cdb255 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1859,14 +1859,13 @@ if constexpr (not (std::is_same_v and std::is_sam for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; - - int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; - if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; + int64_t expanded_rows = num_rows * experts_per_token; if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) { continue;