From a5390dfe5b64076ea56c3aba690348b6759d390e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 10:08:34 +0800 Subject: [PATCH 01/28] more --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cc701bd8b..7d2eaf51d 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,8 @@ def generate_build_meta(aot_build_meta: dict) -> None: "einops", "nvidia-nvshmem-cu12", "nvidia-cudnn-cu12", - "nvidia-cudnn-frontend", + # NOTE MODIFIED rm + # "nvidia-cudnn-frontend", ] generate_build_meta({}) From 94c6d273deeb31f92e33e58ada4bb2b1b17b6067 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:29:47 +0800 Subject: [PATCH 02/28] more --- benchmarks/bench_cutlass_fused_moe.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 165c55a75..940d0f387 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -35,7 +35,9 @@ 96, 128, 256, + 384, # NOTE ADD 512, + 768, # NOTE ADD 1024, 1536, 2048, @@ -53,6 +55,27 @@ FP8_DTYPE = torch.float8_e4m3fn test_configs = [ + # NOTE MODIFIED ADD + *[ + { + "hidden_size": 7168, + "num_experts": num_experts, + "top_k": 8, + "intermediate_size": 2048, + } + for num_experts in [ + 288 // 1, + 288 // 2, + 288 // 4, + 288 // 8, + 288 // 16, + 288 // 32, + 288 // 48, + 288 // 72, + ] + ], + + # --- old --- { "hidden_size": 7168, "num_experts": 256, @@ -199,6 +222,7 @@ def bench_cutlass_fused_moe( f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" ) print(f"execution time: {ms}ms") + print(f"hi {selected_experts=}") if __name__ == "__main__": From 64b0032eff293ed9b19bc3e3889325b4c55ff5fc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:31:41 +0800 Subject: [PATCH 03/28] more --- benchmarks/bench_cutlass_fused_moe.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 940d0f387..7891d786f 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import json import torch from torch.nn import functional as F @@ -218,11 +219,18 @@ def bench_cutlass_fused_moe( output=flash_output, ) ) - print( - f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" - ) - print(f"execution time: {ms}ms") - print(f"hi {selected_experts=}") + # NOTE MODIFIED + print(f"MAIN_OUTPUT=" + json.dumps(dict( + batch_size=batch_size, + num_experts=num_experts, + top_k=top_k, + intermediate_size=intermediate_size, + execution_time_us=ms * 1000, + ))) + # print( + # f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" + # ) + # print(f"execution time: {ms}ms") if __name__ == "__main__": From cfcc7ce9012dd19a168528e7b481278992ce3427 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:35:00 +0800 Subject: [PATCH 04/28] more --- benchmarks/bench_cutlass_fused_moe.py | 116 ++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 7891d786f..907b386cc 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -14,6 +14,8 @@ limitations under the License. """ import json +import os +import sys import torch from torch.nn import functional as F @@ -23,6 +25,120 @@ import flashinfer.fused_moe as fused_moe from flashinfer import fp4_quantize +# ------------------------------------------------------------------------------------------------ + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + # Conflict with Nsight Systems + using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + for _ in 
range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + print(f"{prof_lines=}") + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num) + + return tuple(kernel_times) if is_tuple else kernel_times[0] + + +# ------------------------------------------------------------------------------------------------ + BATCH_SIZES = [ 1, 2, From 0e63dbe3aa26ccb2fb996b3ed94f6c2e382b1f10 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:45:48 +0800 Subject: [PATCH 05/28] more --- benchmarks/bench_cutlass_fused_moe.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 907b386cc..00e856e2d 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -106,7 +106,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') - print(f"{prof_lines=}") + print(f"prof_lines=\n" + "\n".join(prof_lines)) kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) if not with_multiple_kernels: @@ -322,7 +322,21 @@ def bench_cutlass_fused_moe( input_sf=input_sf, output=flash_output, ) - ms = do_bench( + # NOTE MODIFIED + # ms = do_bench( + # lambda: fused_moe.cutlass_fused_moe( + # hidden_states, + # selected_experts.to(torch.int), + # routing_weights, + # w1_q.contiguous().view(torch.long), + # w2_q.contiguous().view(torch.long), + # otype, + # quant_scales=quant_scales, + # input_sf=input_sf, + # output=flash_output, + # ) + # ) + bench_kineto( lambda: fused_moe.cutlass_fused_moe( hidden_states, selected_experts.to(torch.int), @@ -333,8 +347,10 @@ def bench_cutlass_fused_moe( quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - ) + ), + kernel_names="what", ) + # NOTE MODIFIED print(f"MAIN_OUTPUT=" + json.dumps(dict( batch_size=batch_size, From 72ec2c43bac06175c0dc55e52cd0a804f5f88ced Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:50:03 +0800 Subject: [PATCH 06/28] more --- 
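Note: the bench_kineto helper introduced in the previous patches profiles the target
callable with torch.profiler (Kineto), optionally flushes L2 between runs, and averages
the matching kernel rows from the profiler table. A minimal usage sketch, for reference
only — the kernel-name substring, the profiled function, and the trace path below are
placeholders, not names taken from this series:

    # Illustrative sketch only: run_once() and "my_kernel" are assumed names.
    avg_seconds = bench_kineto(
        lambda: run_once(),                   # workload to profile
        kernel_names="my_kernel",             # substring matched against profiler rows
        num_tests=30,                         # repetitions inside the profiling window
        trace_path="/tmp/moe.trace.json.gz",  # optional: also export a chrome trace
    )
    print(f"average kernel time: {avg_seconds * 1e6:.2f} us")
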
benchmarks/bench_cutlass_fused_moe.py | 56 ++++++++++++++------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 00e856e2d..6a3123f32 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -16,6 +16,7 @@ import json import os import sys +import time import torch from torch.nn import functional as F @@ -140,26 +141,27 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, # ------------------------------------------------------------------------------------------------ BATCH_SIZES = [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 384, # NOTE ADD - 512, + # TODO more + # 1, + # 2, + # 4, + # 8, + # 16, + # 24, + # 32, + # 48, + # 64, + # 96, + # 128, + # 256, + # 384, # NOTE ADD + # 512, 768, # NOTE ADD - 1024, - 1536, - 2048, - 3072, - 4096, + # 1024, + # 1536, + # 2048, + # 3072, + # 4096, ] configs = [] @@ -181,14 +183,15 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, "intermediate_size": 2048, } for num_experts in [ - 288 // 1, - 288 // 2, - 288 // 4, - 288 // 8, - 288 // 16, + # TODO more + # 288 // 1, + # 288 // 2, + # 288 // 4, + # 288 // 8, + # 288 // 16, 288 // 32, - 288 // 48, - 288 // 72, + # 288 // 48, + # 288 // 72, ] ], @@ -349,6 +352,7 @@ def bench_cutlass_fused_moe( output=flash_output, ), kernel_names="what", + trace_path=os.environ.get("BENCH_KINETO_TRACE_DIR") + "/" + str(time.time()) + ".json.gz", ) # NOTE MODIFIED From 8caafdedb39440f5fb66ddbb9610c0bd71593c1f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:50:56 +0800 Subject: [PATCH 07/28] more --- benchmarks/bench_cutlass_fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 6a3123f32..4afa7843c 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -116,6 +116,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, # Save chrome traces if trace_path is not None: + print(f"export_chrome_trace to {trace_path=}") profiler.export_chrome_trace(trace_path) # Return average kernel times From 5cbc7d0eaa2f98c8277ef752b3137d47843cb501 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:52:27 +0800 Subject: [PATCH 08/28] more --- benchmarks/bench_cutlass_fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 4afa7843c..ab5590699 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -340,6 +340,7 @@ def bench_cutlass_fused_moe( # output=flash_output, # ) # ) + trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR") bench_kineto( lambda: fused_moe.cutlass_fused_moe( hidden_states, @@ -353,7 +354,7 @@ def bench_cutlass_fused_moe( output=flash_output, ), kernel_names="what", - trace_path=os.environ.get("BENCH_KINETO_TRACE_DIR") + "/" + str(time.time()) + ".json.gz", + trace_path=f"{trace_dir}/{time.time()}.json.gz" if trace_dir else None, ) # NOTE MODIFIED From ab252c7d92c6f58e0a3033718568dfcd0e27c854 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 11:54:02 +0800 Subject: [PATCH 09/28] more --- benchmarks/bench_cutlass_fused_moe.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index ab5590699..abd84cab4 
100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -108,17 +108,19 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') print(f"prof_lines=\n" + "\n".join(prof_lines)) - kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names - assert all([isinstance(name, str) for name in kernel_names]) - if not with_multiple_kernels: - for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + # NOTE MOVED # Save chrome traces if trace_path is not None: print(f"export_chrome_trace to {trace_path=}") profiler.export_chrome_trace(trace_path) + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + # Return average kernel times units = {'ms': 1e3, 'us': 1e6} kernel_times = [] From 862900829436a0b1c041abe633bb8bc306ef4715 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:27:34 +0800 Subject: [PATCH 10/28] more --- benchmarks/bench_cutlass_fused_moe.py | 118 +------------------------- flashinfer/testing/utils.py | 115 +++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 117 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index abd84cab4..ab03589ba 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -25,123 +25,7 @@ import flashinfer import flashinfer.fused_moe as fused_moe from flashinfer import fp4_quantize - -# ------------------------------------------------------------------------------------------------ - -class empty_suppress: - def __enter__(self): - return self - - def __exit__(self, *_): - pass - - -class suppress_stdout_stderr: - def __enter__(self): - self.outnull_file = open(os.devnull, 'w') - self.errnull_file = open(os.devnull, 'w') - - self.old_stdout_fileno_undup = sys.stdout.fileno() - self.old_stderr_fileno_undup = sys.stderr.fileno() - - self.old_stdout_fileno = os.dup(sys.stdout.fileno()) - self.old_stderr_fileno = os.dup(sys.stderr.fileno()) - - self.old_stdout = sys.stdout - self.old_stderr = sys.stderr - - os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) - os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) - - sys.stdout = self.outnull_file - sys.stderr = self.errnull_file - return self - - def __exit__(self, *_): - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr - - os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) - os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) - - os.close(self.old_stdout_fileno) - os.close(self.old_stderr_fileno) - - self.outnull_file.close() - self.errnull_file.close() - - -def bench_kineto(fn, kernel_names, num_tests: int = 30, - suppress_kineto_output: bool = False, - trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): - # Conflict with Nsight Systems - using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) - - # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle - flush_l2_size = int(8e9 // 4) - - # For 
some auto-tuning kernels with prints - fn() - - # Profile - suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress - with suppress(): - schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() - with profiler: - for i in range(2): - for _ in range(num_tests): - if flush_l2: - torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() - fn() - - if not using_nsys: - profiler.step() - - # Return 1 if using Nsight Systems - if using_nsys: - return 1 - - # Parse the profiling table - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tuple = isinstance(kernel_names, tuple) - prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') - print(f"prof_lines=\n" + "\n".join(prof_lines)) - - # NOTE MOVED - # Save chrome traces - if trace_path is not None: - print(f"export_chrome_trace to {trace_path=}") - profiler.export_chrome_trace(trace_path) - - kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names - assert all([isinstance(name, str) for name in kernel_names]) - if not with_multiple_kernels: - for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' - - # Return average kernel times - units = {'ms': 1e3, 'us': 1e6} - kernel_times = [] - for name in kernel_names: - total_time = 0 - total_num = 0 - for line in prof_lines: - if name in line: - time_str = line.split()[-2] - num_str = line.split()[-1] - for unit, scale in units.items(): - if unit in time_str: - total_time += float(time_str.replace(unit, '')) / scale * int(num_str) - total_num += int(num_str) - break - kernel_times.append(total_time / total_num) - - return tuple(kernel_times) if is_tuple else kernel_times[0] - - -# ------------------------------------------------------------------------------------------------ +from flashinfer.testing.utils import bench_kineto BATCH_SIZES = [ # TODO more diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index d160f8aba..3da6ad490 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -206,3 +206,118 @@ def dequantize_fp8(x, x_scale, scale_major_mode): x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") return out + + +# ----------------------- copy and modified from DeepGEMM ----------------------- + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, 
self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + # Conflict with Nsight Systems + using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + print(f"prof_lines=\n" + "\n".join(prof_lines)) + + # NOTE MOVED + # Save chrome traces + if trace_path is not None: + print(f"export_chrome_trace to {trace_path=}") + profiler.export_chrome_trace(trace_path) + + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num) + + return tuple(kernel_times) if is_tuple else kernel_times[0] From 46249131da2a5508d6788a46cfa675bf72827a33 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:30:11 +0800 Subject: [PATCH 11/28] more --- flashinfer/testing/utils.py | 63 ++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index 3da6ad490..609067728 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -13,8 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -from typing import Tuple +import json +import os +import sys +import tempfile +from pathlib import Path +from typing import Tuple, Union, Optional +import torch.distributed as dist import torch from einops import rearrange, reduce, repeat @@ -208,8 +213,6 @@ def dequantize_fp8(x, x_scale, scale_major_mode): return out -# ----------------------- copy and modified from DeepGEMM ----------------------- - class empty_suppress: def __enter__(self): return self @@ -253,10 +256,10 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names, num_tests: int = 30, - suppress_kineto_output: bool = False, - trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): +# copy and modified from DeepGEMM and DeepEP +def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False, + trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, + num_kernels_per_period: int = 1): # Conflict with Nsight Systems using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) @@ -273,6 +276,12 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() with profiler: for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + lhs @ rhs + dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() @@ -291,7 +300,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') print(f"prof_lines=\n" + "\n".join(prof_lines)) - # NOTE MOVED # Save chrome traces if trace_path is not None: print(f"export_chrome_trace to {trace_path=}") @@ -299,25 +307,36 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) - if not with_multiple_kernels: - for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' - # Return average kernel times + # Return average kernel durations units = {'ms': 1e3, 'us': 1e6} - kernel_times = [] + kernel_durations = [] for name in kernel_names: - total_time = 0 - total_num = 0 for line in prof_lines: if name in line: time_str = line.split()[-2] - num_str = line.split()[-1] for unit, scale in units.items(): if unit in time_str: - total_time += float(time_str.replace(unit, '')) / scale * int(num_str) - total_num += int(num_str) + kernel_durations.append(float(time_str.replace(unit, '')) / scale) break - kernel_times.append(total_time / total_num) - - return tuple(kernel_times) if is_tuple else kernel_times[0] + break + + # Expand the kernels by periods + if num_kernels_per_period > 1: + with tempfile.NamedTemporaryFile(suffix='.json') as tmp: + profiler.export_chrome_trace(tmp.name) + profile_data = json.loads(Path(tmp.name).read_text()) 
+ + for i, kernel_name in enumerate(kernel_names): + events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']] + events = sorted(events, key=lambda event: event['ts']) + durations = [event['dur'] / 1e6 for event in events] + assert len(durations) % num_kernels_per_period == 0 + num_kernel_patterns = len(durations) // num_kernels_per_period + kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns + for j in range(num_kernels_per_period)] + + # Return execution durations + return kernel_durations if is_tuple else kernel_durations[0] From 9de03871904229c76db6071e941e1126e4d7bfb9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:33:24 +0800 Subject: [PATCH 12/28] more --- benchmarks/bench_cutlass_fused_moe.py | 5 +++-- flashinfer/testing/utils.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index ab03589ba..7688db420 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -239,8 +239,9 @@ def bench_cutlass_fused_moe( input_sf=input_sf, output=flash_output, ), - kernel_names="what", - trace_path=f"{trace_dir}/{time.time()}.json.gz" if trace_dir else None, + kernel_names="cutlass13device_kernelINS_4gemm6kernel", + num_kernels_per_period=2, + trace_path=f"{trace_dir}/{time.time()}.trace.json.gz" if trace_dir else None, ) # NOTE MODIFIED diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index 609067728..a4fa41de7 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -330,7 +330,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr profile_data = json.loads(Path(tmp.name).read_text()) for i, kernel_name in enumerate(kernel_names): - events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']] + events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] events = sorted(events, key=lambda event: event['ts']) durations = [event['dur'] / 1e6 for event in events] assert len(durations) % num_kernels_per_period == 0 From c89cf7c861b8191fb28a39432bd2d18c56d7287f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:36:44 +0800 Subject: [PATCH 13/28] more --- benchmarks/bench_cutlass_fused_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 7688db420..c5bae6e50 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -227,7 +227,7 @@ def bench_cutlass_fused_moe( # ) # ) trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR") - bench_kineto( + [time_gemm1, time_gemm2] = bench_kineto( lambda: fused_moe.cutlass_fused_moe( hidden_states, selected_experts.to(torch.int), @@ -250,7 +250,8 @@ def bench_cutlass_fused_moe( num_experts=num_experts, top_k=top_k, intermediate_size=intermediate_size, - execution_time_us=ms * 1000, + time_gemm1_us=time_gemm1 * 1e6, + time_gemm2_us=time_gemm2 * 1e6, ))) # print( # f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" From adf2cb178c80a084cffbd1cd3bcb94b942e5634f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:38:42 +0800 Subject: [PATCH 14/28] more --- flashinfer/testing/utils.py | 40 +++++++++++++------------------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff 
--git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index a4fa41de7..654303423 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -298,7 +298,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') - print(f"prof_lines=\n" + "\n".join(prof_lines)) + # print(f"prof_lines=\n" + "\n".join(prof_lines)) # Save chrome traces if trace_path is not None: @@ -310,33 +310,21 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr for name in kernel_names: assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' - # Return average kernel durations - units = {'ms': 1e3, 'us': 1e6} - kernel_durations = [] - for name in kernel_names: - for line in prof_lines: - if name in line: - time_str = line.split()[-2] - for unit, scale in units.items(): - if unit in time_str: - kernel_durations.append(float(time_str.replace(unit, '')) / scale) - break - break + kernel_durations = [None] * len(kernel_names) # Expand the kernels by periods - if num_kernels_per_period > 1: - with tempfile.NamedTemporaryFile(suffix='.json') as tmp: - profiler.export_chrome_trace(tmp.name) - profile_data = json.loads(Path(tmp.name).read_text()) - - for i, kernel_name in enumerate(kernel_names): - events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] - events = sorted(events, key=lambda event: event['ts']) - durations = [event['dur'] / 1e6 for event in events] - assert len(durations) % num_kernels_per_period == 0 - num_kernel_patterns = len(durations) // num_kernels_per_period - kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns - for j in range(num_kernels_per_period)] + with tempfile.NamedTemporaryFile(suffix='.json') as tmp: + profiler.export_chrome_trace(tmp.name) + profile_data = json.loads(Path(tmp.name).read_text()) + + for i, kernel_name in enumerate(kernel_names): + events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] + events = sorted(events, key=lambda event: event['ts']) + durations = [event['dur'] / 1e6 for event in events] + assert len(durations) % num_kernels_per_period == 0 + num_kernel_patterns = len(durations) // num_kernels_per_period + kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns + for j in range(num_kernels_per_period)] # Return execution durations return kernel_durations if is_tuple else kernel_durations[0] From 57eda4a586c8575a60a9a5bb22d9efcb81a3a2db Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:39:34 +0800 Subject: [PATCH 15/28] more --- flashinfer/testing/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index 654303423..e07dd9e1c 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -307,8 +307,6 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) - for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' kernel_durations = [None] * 
len(kernel_names) From 8b606b8f77118679603d92e89654c4373aadfd97 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:40:13 +0800 Subject: [PATCH 16/28] more --- benchmarks/bench_cutlass_fused_moe.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index c5bae6e50..8a484b90c 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -83,18 +83,18 @@ ], # --- old --- - { - "hidden_size": 7168, - "num_experts": 256, - "top_k": 8, - "intermediate_size": 256, - }, - { - "hidden_size": 7168, - "num_experts": 32, - "top_k": 8, - "intermediate_size": 2048, - }, + # { + # "hidden_size": 7168, + # "num_experts": 256, + # "top_k": 8, + # "intermediate_size": 256, + # }, + # { + # "hidden_size": 7168, + # "num_experts": 32, + # "top_k": 8, + # "intermediate_size": 2048, + # }, ] From 7b783221e2f1d815a04fa570f337d66f6e836de6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:41:11 +0800 Subject: [PATCH 17/28] more --- benchmarks/bench_cutlass_fused_moe.py | 51 +++++++++++++-------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index 8a484b90c..1d76859dc 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -28,27 +28,26 @@ from flashinfer.testing.utils import bench_kineto BATCH_SIZES = [ - # TODO more - # 1, - # 2, - # 4, - # 8, - # 16, - # 24, - # 32, - # 48, - # 64, - # 96, - # 128, - # 256, - # 384, # NOTE ADD - # 512, + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 384, # NOTE ADD + 512, 768, # NOTE ADD - # 1024, - # 1536, - # 2048, - # 3072, - # 4096, + 1024, + 1536, + 2048, + 3072, + 4096, ] configs = [] @@ -70,13 +69,13 @@ "intermediate_size": 2048, } for num_experts in [ - # TODO more - # 288 // 1, - # 288 // 2, - # 288 // 4, - # 288 // 8, - # 288 // 16, + 288 // 1, + 288 // 2, + 288 // 4, + 288 // 8, + 288 // 16, 288 // 32, + # TODO support # 288 // 48, # 288 // 72, ] From 8c53c5963ca5a15118f9564ae39c9386daa412a8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:44:09 +0800 Subject: [PATCH 18/28] more --- flashinfer/testing/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index e07dd9e1c..649d9b887 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -298,7 +298,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') - # print(f"prof_lines=\n" + "\n".join(prof_lines)) + print(f"prof_lines=\n" + "\n".join(prof_lines)) # Save chrome traces if trace_path is not None: @@ -319,7 +319,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] events = sorted(events, key=lambda event: event['ts']) durations = [event['dur'] / 1e6 for event in events] - assert len(durations) % num_kernels_per_period == 0 + assert len(durations) % num_kernels_per_period == 0, f"{durations=}" num_kernel_patterns = len(durations) // num_kernels_per_period kernel_durations[i] = 
[sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)] From c4c8a07850bdd21901fb9f52d5bb8d7b9fd5d07b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:47:02 +0800 Subject: [PATCH 19/28] more --- flashinfer/testing/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index 649d9b887..e8b8a5aae 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -319,10 +319,13 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] events = sorted(events, key=lambda event: event['ts']) durations = [event['dur'] / 1e6 for event in events] - assert len(durations) % num_kernels_per_period == 0, f"{durations=}" + if len(durations) % num_kernels_per_period != 0: + print(f"WARN: {len(durations)=} % {num_kernels_per_period=} != 0") + durations = durations[:len(durations) - (len(durations) % num_kernels_per_period)] num_kernel_patterns = len(durations) // num_kernels_per_period kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)] + print(f"{kernel_name=} {durations=}") # Return execution durations return kernel_durations if is_tuple else kernel_durations[0] From a19273cc9773234284fb82d85d22c2ebe7af4a49 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:56:27 +0800 Subject: [PATCH 20/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 benchmarks/bench_trtllm_gen_fused_moe.py.py diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py.py b/benchmarks/bench_trtllm_gen_fused_moe.py.py new file mode 100644 index 000000000..e69de29bb From a129862b772ca77ff13904f2379d7320d1f1c264 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:57:16 +0800 Subject: [PATCH 21/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py.py | 1078 +++++++++++++++++++ 1 file changed, 1078 insertions(+) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py.py b/benchmarks/bench_trtllm_gen_fused_moe.py.py index e69de29bb..b5279cf84 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py.py @@ -0,0 +1,1078 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import math + +import pytest +import torch +from torch.nn import functional as F + +from flashinfer import ( + RoutingMethodType, + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, +) +from flashinfer.fused_moe import trtllm_fp4_block_scale_moe + + +class moe_args: + + def __init__( + self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + hidden_states_scale, + hidden_states_scale_global, + expert_logits, + gemm1_weights, + gemm1_scales, + gemm1_scales_global, + gemm2_weights, + gemm2_scales, + gemm2_scales_global, + permute_info, + use_routing_scales_on_input, + ): + self.num_tokens = num_tokens + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.top_k = top_k + self.padding = padding + self.hidden_states = hidden_states + self.hidden_states_scale = hidden_states_scale + self.hidden_states_scale_global = hidden_states_scale_global + self.expert_logits = expert_logits + self.gemm1_weights = gemm1_weights + self.gemm1_scales = gemm1_scales + self.gemm1_scales_global = gemm1_scales_global + self.gemm2_weights = gemm2_weights + self.gemm2_scales = gemm2_scales + self.gemm2_scales_global = gemm2_scales_global + self.permute_info = permute_info + self.use_routing_scales_on_input = use_routing_scales_on_input + + +class moe_args_dequant: + + def __init__( + self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + expert_logits, + gemm1_weights, + gemm2_weights, + permute_info, + use_routing_scales_on_input, + ): + self.num_tokens = num_tokens + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.top_k = top_k + self.padding = padding + self.hidden_states = hidden_states + self.expert_logits = expert_logits + self.gemm1_weights = gemm1_weights + self.gemm2_weights = gemm2_weights + self.permute_info = permute_info + self.use_routing_scales_on_input = use_routing_scales_on_input + + +def routing_reference(expertLogits, topK, padding): + originalDevice = expertLogits.device + expertLogits = expertLogits.cpu() + numTokens, numExperts = expertLogits.shape + assert topK <= numExperts + + numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64) + expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64) + expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64) + + topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1) + for tokenIdx in range(numTokens): + for k in range(topK): + expandedIdx = tokenIdx * topK + k + expertIndex = topKIndices[tokenIdx, k] + expandedTokenIdxToExpert[expandedIdx] = expertIndex + expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex] + numTokensPerExpert[expertIndex] += 1 + + paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64) + for ii in range(numExperts): + + def divUpMul(a, b): + return (a + b - 1) // b * b + + paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[ + ii + ] + divUpMul(numTokensPerExpert[ii], padding) + permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts] + + expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64) + permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) + permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) + for 
tokenIdx in range(numTokens): + for k in range(topK): + expandedIdx = tokenIdx * topK + k + expert = expandedTokenIdxToExpert[expandedIdx] + offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx] + offsetForExpert = paddedTokensPerExpertPrefixSum[expert] + permutedIdx = offsetForExpert + offsetWithinExpert + + expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx + permutedIdxToExpandedIdx[permutedIdx] = expandedIdx + permutedIdxToTokenIdx[permutedIdx] = tokenIdx + return { + "paddedTokensPerExpertPrefixSum": paddedTokensPerExpertPrefixSum.to( + originalDevice + ), + "permutedBufferSize": permutedBufferSize.item(), + "expandedTokenIdxToPermutedIdx": expandedTokenIdxToPermutedIdx.to( + originalDevice + ), + "permutedIdxToExpandedIdx": permutedIdxToExpandedIdx.to(originalDevice), + "numTokensPerExpert": numTokensPerExpert.to(originalDevice), + "expandedTokenIdxToExpert": expandedTokenIdxToExpert.to(originalDevice), + "topKLogits": topKLogits.to(originalDevice), + "permutedIdxToTokenIdx": permutedIdxToTokenIdx.to(originalDevice), + "topKIndices": topKIndices.to(originalDevice), + } + + +def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor): + scores = F.sigmoid(logits) + scores_with_bias = scores + bias + if n_group > 1: + scores_shape = list(scores_with_bias.shape) + group_scores = torch.sum( + torch.topk( + scores_with_bias.view( + scores_shape[:-1] + [n_group, scores_shape[-1] // n_group] + ), + k=2, + dim=-1, + largest=True, + sorted=True, + )[0], + dim=-1, + ) + _, group_idx = torch.topk( + group_scores, k=topk_group, dim=-1, largest=True, sorted=True + ) + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(-1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]) + .reshape(scores_shape) + ) + scores_with_bias = scores_with_bias * score_mask + + _, topk_idx = torch.topk( + scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True + ) + new_mask = torch.zeros_like(scores) + new_mask.scatter_(-1, topk_idx, 1) + scores = scores * new_mask + score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 + scores = scores / score_sum * routed_scaling_factor + return scores + + +# Tiered TopK routing used by DeepSeek +def routing_reference_no_aux( + expert_logits, + routing_bias, + top_k, + n_groups, + top_k_groups, + routed_scaling, + padding, + use_routing_scales_on_input=False, +): + routing_logits = expert_logits.to(dtype=torch.float, device="cuda") + if use_routing_scales_on_input: + # if using routing scales on input, topK == 1 and the score is a plain sigmoid + scores = F.sigmoid(routing_logits) + else: + scores = noaux_tc_ref( + routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling + ) + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +# TopK -> Softmax +def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): + topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) + topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +# Softmax->TopK -> Normalize +def routing_reference_renormalize_naive(expert_logits, top_k, 
num_experts, padding): + norm_topk_prob = True + scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1) + topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1) + + if norm_topk_prob: # only diff with mixtral sparse moe block! + topk_values /= topk_values.sum(dim=-1, keepdim=True) + topk_values = topk_values.to(expert_logits.dtype) + scores = scores.to(expert_logits.dtype) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +def run_moe_dequant(args, quant_mode=["fp4"]): + # Permute + total_num_padded_tokens = args.permute_info["permutedBufferSize"] + expanded_idx_to_permuted_idx = args.permute_info[ + "expandedTokenIdxToPermutedIdx" + ].cpu() + num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu() + permute_output = torch.full( + (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + for i in range(args.num_tokens): + for j in range(args.top_k): + permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j] + permute_output[permuted_idx] = args.hidden_states[i] + # Gemm1 + gemm1_output = torch.full( + (total_num_padded_tokens, 2 * args.intermediate_size), + float("nan"), + device="cuda", + ).to(torch.float) + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + if my_num_tokens == 0: + continue + my_a = permute_output[i : i + my_num_tokens] + my_b = args.gemm1_weights[expert_idx] + my_c = my_a @ my_b.t() + gemm1_output[i : i + my_num_tokens] = my_c + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + + if args.use_routing_scales_on_input: + assert args.top_k == 1 + # For each token and its top_k experts + for token_idx in range(args.num_tokens): + for k in range(args.top_k): + # Get the permuted index for this token's k-th expert + expanded_idx = token_idx * args.top_k + k + permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] + expert_weight = args.permute_info["topKLogits"].to(torch.float) + # Get the expert weight for this token and expert + weight = expert_weight[token_idx, k] + # Scale the corresponding row in gemm1_output + gemm1_output[permuted_idx] *= weight + + # Activation + activation_output = torch.full( + (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" + ).to(torch.float) + + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + if my_num_tokens == 0: + continue + my_a = gemm1_output[i : i + my_num_tokens] + my_x1 = my_a[:, : args.intermediate_size] + my_x2 = my_a[:, args.intermediate_size :] + activation_output[i : i + my_num_tokens] = F.silu(my_x2) * my_x1 + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + + if quant_mode == "fp4": + activation_output, c_global_sf = quant_dequant_fp4( + activation_output.to(torch.bfloat16), False, True + ) + activation_output = activation_output.to(torch.float) + args.c_global_sf = c_global_sf + + # Gemm2 + gemm2_output = torch.full( + (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + if my_num_tokens == 0: + continue + my_a = 
activation_output[i : i + my_num_tokens] + my_b = args.gemm2_weights[expert_idx] + my_c = my_a @ my_b.t() + gemm2_output[i : i + my_num_tokens] = my_c + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + # Finalize + expert_weight = args.permute_info["topKLogits"].to(torch.float) + finalize_output = torch.full( + (args.num_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + for i in range(args.num_tokens): + acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda") + for top_k_idx in range(args.top_k): + expanded_idx = i * args.top_k + top_k_idx + permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] + original_vector = gemm2_output[permuted_idx] + weight = ( + expert_weight[i, top_k_idx] + if not args.use_routing_scales_on_input + else 1.0 + ) + acc += original_vector * weight + finalize_output[i] = acc + return finalize_output + + +def e2m1_and_ufp8_scale_to_float_tensor_v2( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: torch.Tensor, + sf_vec_size, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +): + float_tensor = e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor.cpu(), + ufp8_scale_tensor.cpu().reshape(-1), + global_scale_tensor.cpu(), + sf_vec_size, + ufp8_type, + is_sf_swizzled_layout, + ) + return float_tensor + + +def e2m1_and_ufp8_scale_batches( + mat_fp4: torch.Tensor, + scale_tensor: torch.Tensor, + global_scale_tensor: torch.Tensor, + sf_vec_size: int, + ufp8_type: int = 1, +): + num_batches = mat_fp4.size(0) + + scale_tensor = scale_tensor.view(num_batches, -1) + + tensors = [ + e2m1_and_ufp8_scale_to_float_tensor_v2( + mat_fp4[b, :, :], scale_tensor[b, :], global_scale_tensor[b], sf_vec_size + ) + for b in range(num_batches) + ] + + result = torch.stack(tensors) + + return result + + +def run_moe_reference_fp4(args): + sf_vec_size = 16 + + hidden_states_dequant = e2m1_and_ufp8_scale_to_float_tensor_v2( + args.hidden_states, + args.hidden_states_scale, + 1 / args.hidden_states_scale_global, + sf_vec_size, + ).cuda() + + gemm1_weights_dequant = e2m1_and_ufp8_scale_batches( + args.gemm1_weights, args.gemm1_scales, 1 / args.gemm1_scales_global, sf_vec_size + ).cuda() + + gemm2_weights_dequant = e2m1_and_ufp8_scale_batches( + args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global, sf_vec_size + ).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + ) + + return run_moe_dequant(args_dequant, "fp4"), args_dequant + + +def quant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): + a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() + sf_vec_size = 16 + + a_fp4, a_sf = fp4_quantize( + a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout + ) + + return a_fp4, a_sf, a_global_sf + + +def quant_fp4_with_global_sf( + a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True +): + """ + Quantize FP4 with pre-calculated global scale factor. + Used specifically for hidden states in CUDA graph capture to avoid runtime computation. 
+ """ + sf_vec_size = 16 + + a_fp4, a_sf = fp4_quantize( + a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout + ) + + return a_fp4, a_sf, a_global_sf + + +def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=True): + quant_a = [] + sfs = [] + global_sfs = [] + for i in range(num_experts): + a_fp4, a_sf, a_global_sf = quant_fp4(a[i], use_ue8m0, is_sf_swizzled_layout) + quant_a.append(a_fp4) + sfs.append(a_sf) + global_sfs.append(a_global_sf) + + result_quant_a = torch.stack(quant_a) + result_sfs = torch.stack(sfs) + result_global_sfs = torch.stack(global_sfs) + + return result_quant_a, result_sfs, result_global_sfs + + +def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): + a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() + sf_vec_size = 16 + + a_fp4, a_sf = fp4_quantize( + a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout + ) + + a_pt = e2m1_and_ufp8_scale_to_float_tensor_v2( + a_fp4.cpu(), a_sf.cpu(), 1 / a_global_sf, sf_vec_size + ) + + return a_pt.cuda(), a_global_sf + + +def check_accuracy(a, b, atol, rtol, percent): + if torch.any(torch.isnan(a)): + raise Exception("NaN in a") + if torch.any(torch.isnan(b)): + raise Exception("NaN in b") + if torch.any(torch.isinf(a)): + raise Exception("Inf in a") + if torch.any(torch.isinf(b)): + raise Exception("Inf in b") + assert a.shape == b.shape + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol) + ) + + +def create_expert_logits(num_token, num_experts, k): + """ + Create deterministic expert logits for testing where specific experts + are guaranteed to be selected for each token. + + Args: + num_token: Number of tokens + num_experts: Number of experts + k: Top-k value (number of experts to select per token) + + Returns: + logits: Expert logits tensor [num_token, num_experts] (CUDA bfloat16) + index: Expected top-k indices [num_token, k] (CUDA) + large_random: The large random values used [num_token, k] (CUDA) + """ + # 1. Create logits tensor + logits = torch.zeros(num_token, num_experts) + + # 2. Set index sequence + final_size = num_token * k + repeat_count = math.ceil(final_size / num_experts) + indices = torch.arange(num_experts, dtype=torch.int32) + indices = indices.repeat(repeat_count) + indices = indices[:final_size] + index = indices.view(num_token, k).contiguous() + + # 3. Generate large random numbers + large_random = torch.randint(5, 11, (num_token, k), dtype=torch.float32) + + # 4. Put the random number to the place we want + for token_id in range(num_token): + for j in range(k): + expert_idx = index[token_id, j] + logits[token_id, expert_idx] = large_random[token_id, j] + + # 5. 
Set smaller random numbers in other places + mask = logits == 0 + logits[mask] = torch.rand(mask.sum()) + + logits = torch.nn.functional.softmax(logits, dim=-1) + + # Convert to CUDA tensors with appropriate dtypes + logits = logits.to(device="cuda", dtype=torch.bfloat16) + index = index.to(device="cuda") + large_random = large_random.to(device="cuda") + + return logits, index, large_random + + +def compute_moe_reference_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, +): + """ + Compute the reference MoE output using dequantized operations with full routing support. + + Returns: + output_dequant_reference: Reference output tensor + args_dequant: Dequantized arguments for debugging + """ + use_ue8m0 = False + + # Quantize hidden states + ( + hidden_states_fp4_bytes, + hidden_states_scale_fp4_bytes, + hidden_states_scale_global, + ) = quant_fp4(hidden_states, use_ue8m0, True) + + # Quantize the weights for FC1 + gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( + quant_fp4_batches(gemm1_weights, num_experts, use_ue8m0, True) + ) + + # Quantize the weights for FC2 + gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( + quant_fp4_batches(gemm2_weights, num_experts, use_ue8m0, True) + ) + + # Generate routing info based on method + if routing_method_type == RoutingMethodType.DeepSeekV3: + permute_info, scores = routing_reference_no_aux( + expert_logits, + routing_bias, + top_k, + n_groups, + top_k_groups, + routed_scaling, + padding, + ) + elif routing_method_type == RoutingMethodType.Renormalize: + permute_info, scores = routing_reference_renormalize( + expert_logits, top_k, num_experts, padding + ) + elif routing_method_type == RoutingMethodType.RenormalizeNaive: + permute_info, scores = routing_reference_renormalize_naive( + expert_logits, top_k, num_experts, padding + ) + else: + raise NotImplementedError( + f"Routing method {routing_method_type} not implemented" + ) + + # Create arguments for reference computation + args = moe_args( + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states_fp4_bytes, + hidden_states_scale_fp4_bytes, + hidden_states_scale_global, + scores, + gemm1_weights_fp4_bytes, + gemm1_scales_fp4_bytes, + gemm1_scales_global, + gemm2_weights_fp4_bytes, + gemm2_scales_fp4_bytes, + gemm2_scales_global, + permute_info, + False, + ) + + # Run the reference implementation + output_dequant_reference, args_dequant = run_moe_reference_fp4(args) + + return output_dequant_reference, args_dequant, args + + +def compute_moe_actual_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, + tile_tokens_dim, + args_dequant, + args, +): + """ + Compute the actual MoE output using the optimized kernel with full routing support. + + Returns: + output_dequant_actual: Actual output tensor from the kernel + """ + + def prepare_static_weights(): + """ + Handle all static weight-related preprocessing. + This should be done once at model load time in production. 
+ + Returns: + Dict containing all preprocessed weight tensors and scale factors + """ + use_ue8m0 = False + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + + # Quantize weights with linear layout for kernels + _, gemm1_scales_linear_fp4_bytes, _ = quant_fp4_batches( + gemm1_weights, num_experts, use_ue8m0, False + ) + _, gemm2_scales_linear_fp4_bytes, _ = quant_fp4_batches( + gemm2_weights, num_experts, use_ue8m0, False + ) + + # Convert quantized weights to proper formats + gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2 + ) # packed fp4 + gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape( + num_experts, 2 * intermediate_size, hidden_size // 16 + ) # fp8 scaling factors + + gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 2 + ) # packed fp4 + gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape( + num_experts, hidden_size, intermediate_size // 16 + ) # fp8 scaling factors + + # Reorder rows of W1 and scales for fused gated activation + gemm1_weights_fp4_interleaved = [] + gemm1_scales_fp4_interleaved = [] + for i in range(num_experts): + gemm1_weights_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()) + ) + gemm1_scales_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone()) + ) + + # Stack weights and scales for all experts + gemm1_weights_fp4_interleaved = torch.stack( + gemm1_weights_fp4_interleaved + ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2) + gemm1_scales_fp4_interleaved = torch.stack( + gemm1_scales_fp4_interleaved + ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_fp4_shuffled = [] + gemm1_scales_fp4_shuffled = [] + gemm2_weights_fp4_shuffled = [] + gemm2_scales_fp4_shuffled = [] + for i in range(num_experts): + gemm1_weights_fp4_shuffled.append( + shuffle_matrix_a( + gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm1_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) + + gemm2_weights_fp4_shuffled.append( + shuffle_matrix_a( + gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm2_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m + ) + ) + + # Stack weights for all experts + gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) + gemm1_scales_fp4_shuffled = ( + torch.stack(gemm1_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + ) + + gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) + gemm2_scales_fp4_shuffled = ( + torch.stack(gemm2_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, hidden_size, intermediate_size // 16) + ) + + # Calculate scaling factors that depend on weights + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) + scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( + 1.0 / args.hidden_states_scale_global + ) + scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * ( + 1.0 / args.gemm2_scales_global + ) + + return { + 
"gemm1_weights_fp4_shuffled": gemm1_weights_fp4_shuffled, + "gemm1_scales_fp4_shuffled": gemm1_scales_fp4_shuffled, + "gemm2_weights_fp4_shuffled": gemm2_weights_fp4_shuffled, + "gemm2_scales_fp4_shuffled": gemm2_scales_fp4_shuffled, + "scale_c_fc1": scale_c_fc1, + "scale_gate_fc1": scale_gate_fc1, + "scale_c_fc2": scale_c_fc2, + } + + # Process static weights (would be cached in production) + static_data = prepare_static_weights() + + # Calculate global scale factor for hidden states offline (precalculated parameter) + hidden_states_global_sf = (448 * 6) / hidden_states.float().abs().nan_to_num().max() + + use_ue8m0 = False + do_finalize = True + + # Quantize hidden states with precalculated global scale + hidden_states_fp4_bytes, hidden_states_scale_linear_fp4_bytes, _ = ( + quant_fp4_with_global_sf( + hidden_states, hidden_states_global_sf, use_ue8m0, False + ) + ) + hidden_states_fp4 = hidden_states_fp4_bytes.reshape(num_tokens, hidden_size // 2) + hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape(-1) + + output = trtllm_fp4_block_scale_moe( + expert_logits, + routing_bias, + hidden_states_fp4, + hidden_states_scale_linear_fp4, + static_data["gemm1_weights_fp4_shuffled"], + static_data["gemm1_scales_fp4_shuffled"], + static_data["gemm2_weights_fp4_shuffled"], + static_data["gemm2_scales_fp4_shuffled"], + static_data["scale_c_fc1"], + static_data["scale_gate_fc1"], + static_data["scale_c_fc2"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + tile_tokens_dim, + routing_method_type, + do_finalize=True, + ) + + output_dequant_actual = output[0].to(torch.float) + + return output_dequant_actual + + +def compare_moe_outputs( + output_dequant_reference, + output_dequant_actual, + seed, + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + routing_method_type, +): + """ + Compare reference and actual MoE outputs and perform accuracy analysis. 
+ + Raises: + Exception: If accuracy test fails + """ + # Use check_accuracy to validate - it will raise exception if test fails + check_accuracy( + output_dequant_reference, + output_dequant_actual, + atol=0.1, + rtol=0.85, + percent=0.925, + ) + + +@pytest.mark.parametrize("num_tokens", [1, 1024, 4096]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192]) +@pytest.mark.parametrize( + "routing_info", + [ + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + id="RoutingDSv3", + ), + pytest.param( + { + "num_experts": 72, + "top_k": 6, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + id="RoutingDSlite", + ), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + }, + id="RoutingRenormalize", + ), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.RenormalizeNaive, + }, + id="RoutingRenormalizeNaive", + ), + ], +) +def test_moe_nvfp4( + num_tokens, + hidden_size, + intermediate_size, + routing_info, +): + seed = 42 + torch.random.manual_seed(seed) + + # Extract routing configuration + top_k = routing_info["top_k"] + padding = routing_info["padding"] + n_groups = routing_info["n_groups"] + top_k_groups = routing_info["top_k_groups"] + routed_scaling = routing_info["routed_scaling"] + num_experts = routing_info["num_experts"] + routing_method_type = routing_info["routing_method_type"] + tile_tokens_dim = 8 + + # Validation checks + assert top_k <= num_experts + assert top_k <= 8 + if (top_k_groups is not None) and (n_groups is not None): + assert top_k_groups <= 4 + assert num_experts > n_groups + assert num_experts % n_groups == 0 + assert num_experts % 4 == 0 + assert top_k < (top_k_groups * num_experts / n_groups) + + # Create expert logits based on routing method + if routing_method_type == RoutingMethodType.DeepSeekV3: + expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( + torch.float + ) + elif ( + routing_method_type == RoutingMethodType.RenormalizeNaive + or routing_method_type == RoutingMethodType.Renormalize + ): + expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( + torch.bfloat16 + ) + + # Handle routing bias + if routing_info["has_routing_bias"]: + routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) + else: + routing_bias = None + + hidden_states = 2 * torch.randn( + (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 + ) + gemm1_weights = torch.randn( + (num_experts, 2 * intermediate_size, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ) + gemm2_weights = torch.randn( + (num_experts, hidden_size, intermediate_size), + device="cuda", + dtype=torch.bfloat16, + ) + + # Compute reference output with updated routing method handling + output_dequant_reference, args_dequant, args = compute_moe_reference_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + 
hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, + ) + + # Compute actual output using optimized kernel + output_dequant_actual = compute_moe_actual_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, + tile_tokens_dim, + args_dequant, + args, + ) + + # Compare outputs - will raise exception if test fails + compare_moe_outputs( + output_dequant_reference, + output_dequant_actual, + seed, + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + routing_method_type, + ) From 3d37611c56390562c4a6b63ed8c6f39d507660e0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 13:58:54 +0800 Subject: [PATCH 22/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py.py | 621 -------------------- 1 file changed, 621 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py.py b/benchmarks/bench_trtllm_gen_fused_moe.py.py index b5279cf84..3dd346622 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py.py @@ -31,348 +31,6 @@ from flashinfer.fused_moe import trtllm_fp4_block_scale_moe -class moe_args: - - def __init__( - self, - num_tokens, - num_experts, - hidden_size, - intermediate_size, - top_k, - padding, - hidden_states, - hidden_states_scale, - hidden_states_scale_global, - expert_logits, - gemm1_weights, - gemm1_scales, - gemm1_scales_global, - gemm2_weights, - gemm2_scales, - gemm2_scales_global, - permute_info, - use_routing_scales_on_input, - ): - self.num_tokens = num_tokens - self.num_experts = num_experts - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.top_k = top_k - self.padding = padding - self.hidden_states = hidden_states - self.hidden_states_scale = hidden_states_scale - self.hidden_states_scale_global = hidden_states_scale_global - self.expert_logits = expert_logits - self.gemm1_weights = gemm1_weights - self.gemm1_scales = gemm1_scales - self.gemm1_scales_global = gemm1_scales_global - self.gemm2_weights = gemm2_weights - self.gemm2_scales = gemm2_scales - self.gemm2_scales_global = gemm2_scales_global - self.permute_info = permute_info - self.use_routing_scales_on_input = use_routing_scales_on_input - - -class moe_args_dequant: - - def __init__( - self, - num_tokens, - num_experts, - hidden_size, - intermediate_size, - top_k, - padding, - hidden_states, - expert_logits, - gemm1_weights, - gemm2_weights, - permute_info, - use_routing_scales_on_input, - ): - self.num_tokens = num_tokens - self.num_experts = num_experts - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.top_k = top_k - self.padding = padding - self.hidden_states = hidden_states - self.expert_logits = expert_logits - self.gemm1_weights = gemm1_weights - self.gemm2_weights = gemm2_weights - self.permute_info = permute_info - self.use_routing_scales_on_input = use_routing_scales_on_input - - -def routing_reference(expertLogits, topK, padding): - originalDevice = expertLogits.device - expertLogits = expertLogits.cpu() - numTokens, numExperts = expertLogits.shape - assert topK <= numExperts - - numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64) - expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64) - expandedTokenIdxToIdxInExpert = -torch.ones(numTokens 
* topK, dtype=torch.int64) - - topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1) - for tokenIdx in range(numTokens): - for k in range(topK): - expandedIdx = tokenIdx * topK + k - expertIndex = topKIndices[tokenIdx, k] - expandedTokenIdxToExpert[expandedIdx] = expertIndex - expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex] - numTokensPerExpert[expertIndex] += 1 - - paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64) - for ii in range(numExperts): - - def divUpMul(a, b): - return (a + b - 1) // b * b - - paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[ - ii - ] + divUpMul(numTokensPerExpert[ii], padding) - permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts] - - expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64) - permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) - permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) - for tokenIdx in range(numTokens): - for k in range(topK): - expandedIdx = tokenIdx * topK + k - expert = expandedTokenIdxToExpert[expandedIdx] - offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx] - offsetForExpert = paddedTokensPerExpertPrefixSum[expert] - permutedIdx = offsetForExpert + offsetWithinExpert - - expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx - permutedIdxToExpandedIdx[permutedIdx] = expandedIdx - permutedIdxToTokenIdx[permutedIdx] = tokenIdx - return { - "paddedTokensPerExpertPrefixSum": paddedTokensPerExpertPrefixSum.to( - originalDevice - ), - "permutedBufferSize": permutedBufferSize.item(), - "expandedTokenIdxToPermutedIdx": expandedTokenIdxToPermutedIdx.to( - originalDevice - ), - "permutedIdxToExpandedIdx": permutedIdxToExpandedIdx.to(originalDevice), - "numTokensPerExpert": numTokensPerExpert.to(originalDevice), - "expandedTokenIdxToExpert": expandedTokenIdxToExpert.to(originalDevice), - "topKLogits": topKLogits.to(originalDevice), - "permutedIdxToTokenIdx": permutedIdxToTokenIdx.to(originalDevice), - "topKIndices": topKIndices.to(originalDevice), - } - - -def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor): - scores = F.sigmoid(logits) - scores_with_bias = scores + bias - if n_group > 1: - scores_shape = list(scores_with_bias.shape) - group_scores = torch.sum( - torch.topk( - scores_with_bias.view( - scores_shape[:-1] + [n_group, scores_shape[-1] // n_group] - ), - k=2, - dim=-1, - largest=True, - sorted=True, - )[0], - dim=-1, - ) - _, group_idx = torch.topk( - group_scores, k=topk_group, dim=-1, largest=True, sorted=True - ) - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(-1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]) - .reshape(scores_shape) - ) - scores_with_bias = scores_with_bias * score_mask - - _, topk_idx = torch.topk( - scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True - ) - new_mask = torch.zeros_like(scores) - new_mask.scatter_(-1, topk_idx, 1) - scores = scores * new_mask - score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 - scores = scores / score_sum * routed_scaling_factor - return scores - - -# Tiered TopK routing used by DeepSeek -def routing_reference_no_aux( - expert_logits, - routing_bias, - top_k, - n_groups, - top_k_groups, - routed_scaling, - padding, - use_routing_scales_on_input=False, -): - routing_logits = expert_logits.to(dtype=torch.float, device="cuda") - if 
use_routing_scales_on_input: - # if using routing scales on input, topK == 1 and the score is a plain sigmoid - scores = F.sigmoid(routing_logits) - else: - scores = noaux_tc_ref( - routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling - ) - permute_info = routing_reference(scores, top_k, padding) - return permute_info, scores - - -# TopK -> Softmax -def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): - topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) - topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1) - - new_mask = torch.zeros_like(expert_logits) - new_mask.scatter_(-1, topk_idx, 1) - scores = expert_logits * new_mask - - for i in range(topk_idx.shape[0]): - for j in range(topk_idx.shape[1]): - scores[i, topk_idx[i, j]] = topk_values[i, j] - permute_info = routing_reference(scores, top_k, padding) - return permute_info, scores - - -# Softmax->TopK -> Normalize -def routing_reference_renormalize_naive(expert_logits, top_k, num_experts, padding): - norm_topk_prob = True - scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1) - topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1) - - if norm_topk_prob: # only diff with mixtral sparse moe block! - topk_values /= topk_values.sum(dim=-1, keepdim=True) - topk_values = topk_values.to(expert_logits.dtype) - scores = scores.to(expert_logits.dtype) - - new_mask = torch.zeros_like(expert_logits) - new_mask.scatter_(-1, topk_idx, 1) - scores = expert_logits * new_mask - - for i in range(topk_idx.shape[0]): - for j in range(topk_idx.shape[1]): - scores[i, topk_idx[i, j]] = topk_values[i, j] - permute_info = routing_reference(scores, top_k, padding) - return permute_info, scores - - -def run_moe_dequant(args, quant_mode=["fp4"]): - # Permute - total_num_padded_tokens = args.permute_info["permutedBufferSize"] - expanded_idx_to_permuted_idx = args.permute_info[ - "expandedTokenIdxToPermutedIdx" - ].cpu() - num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu() - permute_output = torch.full( - (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) - for i in range(args.num_tokens): - for j in range(args.top_k): - permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j] - permute_output[permuted_idx] = args.hidden_states[i] - # Gemm1 - gemm1_output = torch.full( - (total_num_padded_tokens, 2 * args.intermediate_size), - float("nan"), - device="cuda", - ).to(torch.float) - i = 0 - for expert_idx in range(args.num_experts): - my_num_tokens = num_tokens_per_expert[expert_idx] - if my_num_tokens == 0: - continue - my_a = permute_output[i : i + my_num_tokens] - my_b = args.gemm1_weights[expert_idx] - my_c = my_a @ my_b.t() - gemm1_output[i : i + my_num_tokens] = my_c - i += my_num_tokens - i = (i + args.padding - 1) // args.padding * args.padding - - if args.use_routing_scales_on_input: - assert args.top_k == 1 - # For each token and its top_k experts - for token_idx in range(args.num_tokens): - for k in range(args.top_k): - # Get the permuted index for this token's k-th expert - expanded_idx = token_idx * args.top_k + k - permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] - expert_weight = args.permute_info["topKLogits"].to(torch.float) - # Get the expert weight for this token and expert - weight = expert_weight[token_idx, k] - # Scale the corresponding row in gemm1_output - gemm1_output[permuted_idx] *= weight - - # Activation - activation_output = torch.full( - 
(total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" - ).to(torch.float) - - i = 0 - for expert_idx in range(args.num_experts): - my_num_tokens = num_tokens_per_expert[expert_idx] - if my_num_tokens == 0: - continue - my_a = gemm1_output[i : i + my_num_tokens] - my_x1 = my_a[:, : args.intermediate_size] - my_x2 = my_a[:, args.intermediate_size :] - activation_output[i : i + my_num_tokens] = F.silu(my_x2) * my_x1 - i += my_num_tokens - i = (i + args.padding - 1) // args.padding * args.padding - - if quant_mode == "fp4": - activation_output, c_global_sf = quant_dequant_fp4( - activation_output.to(torch.bfloat16), False, True - ) - activation_output = activation_output.to(torch.float) - args.c_global_sf = c_global_sf - - # Gemm2 - gemm2_output = torch.full( - (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) - i = 0 - for expert_idx in range(args.num_experts): - my_num_tokens = num_tokens_per_expert[expert_idx] - if my_num_tokens == 0: - continue - my_a = activation_output[i : i + my_num_tokens] - my_b = args.gemm2_weights[expert_idx] - my_c = my_a @ my_b.t() - gemm2_output[i : i + my_num_tokens] = my_c - i += my_num_tokens - i = (i + args.padding - 1) // args.padding * args.padding - # Finalize - expert_weight = args.permute_info["topKLogits"].to(torch.float) - finalize_output = torch.full( - (args.num_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) - for i in range(args.num_tokens): - acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda") - for top_k_idx in range(args.top_k): - expanded_idx = i * args.top_k + top_k_idx - permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] - original_vector = gemm2_output[permuted_idx] - weight = ( - expert_weight[i, top_k_idx] - if not args.use_routing_scales_on_input - else 1.0 - ) - acc += original_vector * weight - finalize_output[i] = acc - return finalize_output - - def e2m1_and_ufp8_scale_to_float_tensor_v2( e2m1_tensor: torch.Tensor, ufp8_scale_tensor: torch.Tensor, @@ -392,63 +50,6 @@ def e2m1_and_ufp8_scale_to_float_tensor_v2( return float_tensor -def e2m1_and_ufp8_scale_batches( - mat_fp4: torch.Tensor, - scale_tensor: torch.Tensor, - global_scale_tensor: torch.Tensor, - sf_vec_size: int, - ufp8_type: int = 1, -): - num_batches = mat_fp4.size(0) - - scale_tensor = scale_tensor.view(num_batches, -1) - - tensors = [ - e2m1_and_ufp8_scale_to_float_tensor_v2( - mat_fp4[b, :, :], scale_tensor[b, :], global_scale_tensor[b], sf_vec_size - ) - for b in range(num_batches) - ] - - result = torch.stack(tensors) - - return result - - -def run_moe_reference_fp4(args): - sf_vec_size = 16 - - hidden_states_dequant = e2m1_and_ufp8_scale_to_float_tensor_v2( - args.hidden_states, - args.hidden_states_scale, - 1 / args.hidden_states_scale_global, - sf_vec_size, - ).cuda() - - gemm1_weights_dequant = e2m1_and_ufp8_scale_batches( - args.gemm1_weights, args.gemm1_scales, 1 / args.gemm1_scales_global, sf_vec_size - ).cuda() - - gemm2_weights_dequant = e2m1_and_ufp8_scale_batches( - args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global, sf_vec_size - ).cuda() - - args_dequant = moe_args_dequant( - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.padding, - hidden_states_dequant, - args.expert_logits, - gemm1_weights_dequant, - gemm2_weights_dequant, - args.permute_info, - args.use_routing_scales_on_input, - ) - - return run_moe_dequant(args_dequant, "fp4"), args_dequant def quant_fp4(a, 
use_ue8m0=False, is_sf_swizzled_layout=True): @@ -510,169 +111,6 @@ def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): return a_pt.cuda(), a_global_sf -def check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): - raise Exception("NaN in a") - if torch.any(torch.isnan(b)): - raise Exception("NaN in b") - if torch.any(torch.isinf(a)): - raise Exception("Inf in a") - if torch.any(torch.isinf(b)): - raise Exception("Inf in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() - if mismatch_percent > 1 - percent: - raise Exception( - "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol) - ) - - -def create_expert_logits(num_token, num_experts, k): - """ - Create deterministic expert logits for testing where specific experts - are guaranteed to be selected for each token. - - Args: - num_token: Number of tokens - num_experts: Number of experts - k: Top-k value (number of experts to select per token) - - Returns: - logits: Expert logits tensor [num_token, num_experts] (CUDA bfloat16) - index: Expected top-k indices [num_token, k] (CUDA) - large_random: The large random values used [num_token, k] (CUDA) - """ - # 1. Create logits tensor - logits = torch.zeros(num_token, num_experts) - - # 2. Set index sequence - final_size = num_token * k - repeat_count = math.ceil(final_size / num_experts) - indices = torch.arange(num_experts, dtype=torch.int32) - indices = indices.repeat(repeat_count) - indices = indices[:final_size] - index = indices.view(num_token, k).contiguous() - - # 3. Generate large random numbers - large_random = torch.randint(5, 11, (num_token, k), dtype=torch.float32) - - # 4. Put the random number to the place we want - for token_id in range(num_token): - for j in range(k): - expert_idx = index[token_id, j] - logits[token_id, expert_idx] = large_random[token_id, j] - - # 5. Set smaller random numbers in other places - mask = logits == 0 - logits[mask] = torch.rand(mask.sum()) - - logits = torch.nn.functional.softmax(logits, dim=-1) - - # Convert to CUDA tensors with appropriate dtypes - logits = logits.to(device="cuda", dtype=torch.bfloat16) - index = index.to(device="cuda") - large_random = large_random.to(device="cuda") - - return logits, index, large_random - - -def compute_moe_reference_with_routing( - num_tokens, - hidden_size, - intermediate_size, - num_experts, - expert_logits, - routing_bias, - hidden_states, - gemm1_weights, - gemm2_weights, - top_k, - padding, - n_groups, - top_k_groups, - routed_scaling, - routing_method_type, -): - """ - Compute the reference MoE output using dequantized operations with full routing support. 
- - Returns: - output_dequant_reference: Reference output tensor - args_dequant: Dequantized arguments for debugging - """ - use_ue8m0 = False - - # Quantize hidden states - ( - hidden_states_fp4_bytes, - hidden_states_scale_fp4_bytes, - hidden_states_scale_global, - ) = quant_fp4(hidden_states, use_ue8m0, True) - - # Quantize the weights for FC1 - gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( - quant_fp4_batches(gemm1_weights, num_experts, use_ue8m0, True) - ) - - # Quantize the weights for FC2 - gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( - quant_fp4_batches(gemm2_weights, num_experts, use_ue8m0, True) - ) - - # Generate routing info based on method - if routing_method_type == RoutingMethodType.DeepSeekV3: - permute_info, scores = routing_reference_no_aux( - expert_logits, - routing_bias, - top_k, - n_groups, - top_k_groups, - routed_scaling, - padding, - ) - elif routing_method_type == RoutingMethodType.Renormalize: - permute_info, scores = routing_reference_renormalize( - expert_logits, top_k, num_experts, padding - ) - elif routing_method_type == RoutingMethodType.RenormalizeNaive: - permute_info, scores = routing_reference_renormalize_naive( - expert_logits, top_k, num_experts, padding - ) - else: - raise NotImplementedError( - f"Routing method {routing_method_type} not implemented" - ) - - # Create arguments for reference computation - args = moe_args( - num_tokens, - num_experts, - hidden_size, - intermediate_size, - top_k, - padding, - hidden_states_fp4_bytes, - hidden_states_scale_fp4_bytes, - hidden_states_scale_global, - scores, - gemm1_weights_fp4_bytes, - gemm1_scales_fp4_bytes, - gemm1_scales_global, - gemm2_weights_fp4_bytes, - gemm2_scales_fp4_bytes, - gemm2_scales_global, - permute_info, - False, - ) - - # Run the reference implementation - output_dequant_reference, args_dequant = run_moe_reference_fp4(args) - - return output_dequant_reference, args_dequant, args - def compute_moe_actual_with_routing( num_tokens, @@ -874,33 +312,6 @@ def prepare_static_weights(): return output_dequant_actual -def compare_moe_outputs( - output_dequant_reference, - output_dequant_actual, - seed, - num_tokens, - hidden_size, - intermediate_size, - num_experts, - top_k, - routing_method_type, -): - """ - Compare reference and actual MoE outputs and perform accuracy analysis. 
- - Raises: - Exception: If accuracy test fails - """ - # Use check_accuracy to validate - it will raise exception if test fails - check_accuracy( - output_dequant_reference, - output_dequant_actual, - atol=0.1, - rtol=0.85, - percent=0.925, - ) - - @pytest.mark.parametrize("num_tokens", [1, 1024, 4096]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192]) @@ -1023,25 +434,6 @@ def test_moe_nvfp4( dtype=torch.bfloat16, ) - # Compute reference output with updated routing method handling - output_dequant_reference, args_dequant, args = compute_moe_reference_with_routing( - num_tokens, - hidden_size, - intermediate_size, - num_experts, - expert_logits, - routing_bias, - hidden_states, - gemm1_weights, - gemm2_weights, - top_k, - padding, - n_groups, - top_k_groups, - routed_scaling, - routing_method_type, - ) - # Compute actual output using optimized kernel output_dequant_actual = compute_moe_actual_with_routing( num_tokens, @@ -1063,16 +455,3 @@ def test_moe_nvfp4( args_dequant, args, ) - - # Compare outputs - will raise exception if test fails - compare_moe_outputs( - output_dequant_reference, - output_dequant_actual, - seed, - num_tokens, - hidden_size, - intermediate_size, - num_experts, - top_k, - routing_method_type, - ) From 62eaa46468ec69ffc971f17ca043a441913505c4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:01:28 +0800 Subject: [PATCH 23/28] more --- ...h_trtllm_gen_fused_moe.py.py => bench_trtllm_gen_fused_moe.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmarks/{bench_trtllm_gen_fused_moe.py.py => bench_trtllm_gen_fused_moe.py} (100%) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py.py b/benchmarks/bench_trtllm_gen_fused_moe.py similarity index 100% rename from benchmarks/bench_trtllm_gen_fused_moe.py.py rename to benchmarks/bench_trtllm_gen_fused_moe.py From dbecb4a1d8a4c08ad36286076d3f3fbb8a69055f Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:09:42 +0800 Subject: [PATCH 24/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py | 128 ++++++++++++----------- 1 file changed, 67 insertions(+), 61 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py b/benchmarks/bench_trtllm_gen_fused_moe.py index 3dd346622..ab94d65f3 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py @@ -31,6 +31,62 @@ from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +BATCH_SIZES = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, +] + +test_configs = [ + # NOTE MODIFIED ADD + *[ + { + "hidden_size": 7168, + "intermediate_size": 2048, + # RoutingDSv3 + "routing_info": { + # TODO correct? 
+ "num_experts": num_experts, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + } + for num_experts in [ + 288 // 1, + 288 // 2, + 288 // 4, + 288 // 8, + 288 // 16, + 288 // 32, + # TODO support + # 288 // 48, + # 288 // 72, + ] + ], +] + def e2m1_and_ufp8_scale_to_float_tensor_v2( e2m1_tensor: torch.Tensor, ufp8_scale_tensor: torch.Tensor, @@ -312,67 +368,7 @@ def prepare_static_weights(): return output_dequant_actual -@pytest.mark.parametrize("num_tokens", [1, 1024, 4096]) -@pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192]) -@pytest.mark.parametrize( - "routing_info", - [ - pytest.param( - { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - }, - id="RoutingDSv3", - ), - pytest.param( - { - "num_experts": 72, - "top_k": 6, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - }, - id="RoutingDSlite", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - }, - id="RoutingRenormalize", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.RenormalizeNaive, - }, - id="RoutingRenormalizeNaive", - ), - ], -) -def test_moe_nvfp4( +def run_one( num_tokens, hidden_size, intermediate_size, @@ -455,3 +451,13 @@ def test_moe_nvfp4( args_dequant, args, ) + +if __name__ == '__main__': + for config in test_configs: + for batch_size in BATCH_SIZES: + run_one( + num_tokens=batch_size, + hidden_size=config["hidden_size"], + intermediate_size=config["intermediate_size"], + routing_info=config["routing_info"], + ) From 91863e4da8ed40bcde69c289578b4d269b5e462c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:10:28 +0800 Subject: [PATCH 25/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py | 745 +++++++++++++++++++++-- 1 file changed, 680 insertions(+), 65 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py b/benchmarks/bench_trtllm_gen_fused_moe.py index ab94d65f3..d25ca5b7f 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py @@ -31,61 +31,347 @@ from flashinfer.fused_moe import trtllm_fp4_block_scale_moe -BATCH_SIZES = [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 384, - 512, - 768, - 1024, - 1536, - 2048, - 3072, - 4096, -] - -test_configs = [ - # NOTE MODIFIED ADD - *[ - { - "hidden_size": 7168, - "intermediate_size": 2048, - # RoutingDSv3 - "routing_info": { - # TODO correct? 
- "num_experts": num_experts, - "top_k": 8, - "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - }, - } - for num_experts in [ - 288 // 1, - 288 // 2, - 288 // 4, - 288 // 8, - 288 // 16, - 288 // 32, - # TODO support - # 288 // 48, - # 288 // 72, - ] - ], -] +class moe_args: + + def __init__( + self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + hidden_states_scale, + hidden_states_scale_global, + expert_logits, + gemm1_weights, + gemm1_scales, + gemm1_scales_global, + gemm2_weights, + gemm2_scales, + gemm2_scales_global, + permute_info, + use_routing_scales_on_input, + ): + self.num_tokens = num_tokens + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.top_k = top_k + self.padding = padding + self.hidden_states = hidden_states + self.hidden_states_scale = hidden_states_scale + self.hidden_states_scale_global = hidden_states_scale_global + self.expert_logits = expert_logits + self.gemm1_weights = gemm1_weights + self.gemm1_scales = gemm1_scales + self.gemm1_scales_global = gemm1_scales_global + self.gemm2_weights = gemm2_weights + self.gemm2_scales = gemm2_scales + self.gemm2_scales_global = gemm2_scales_global + self.permute_info = permute_info + self.use_routing_scales_on_input = use_routing_scales_on_input + + +class moe_args_dequant: + + def __init__( + self, + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states, + expert_logits, + gemm1_weights, + gemm2_weights, + permute_info, + use_routing_scales_on_input, + ): + self.num_tokens = num_tokens + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.top_k = top_k + self.padding = padding + self.hidden_states = hidden_states + self.expert_logits = expert_logits + self.gemm1_weights = gemm1_weights + self.gemm2_weights = gemm2_weights + self.permute_info = permute_info + self.use_routing_scales_on_input = use_routing_scales_on_input + + +def routing_reference(expertLogits, topK, padding): + originalDevice = expertLogits.device + expertLogits = expertLogits.cpu() + numTokens, numExperts = expertLogits.shape + assert topK <= numExperts + + numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64) + expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64) + expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64) + + topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1) + for tokenIdx in range(numTokens): + for k in range(topK): + expandedIdx = tokenIdx * topK + k + expertIndex = topKIndices[tokenIdx, k] + expandedTokenIdxToExpert[expandedIdx] = expertIndex + expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex] + numTokensPerExpert[expertIndex] += 1 + + paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64) + for ii in range(numExperts): + + def divUpMul(a, b): + return (a + b - 1) // b * b + + paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[ + ii + ] + divUpMul(numTokensPerExpert[ii], padding) + permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts] + + expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64) + permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) + permutedIdxToTokenIdx = 
-torch.ones(permutedBufferSize, dtype=torch.int64) + for tokenIdx in range(numTokens): + for k in range(topK): + expandedIdx = tokenIdx * topK + k + expert = expandedTokenIdxToExpert[expandedIdx] + offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx] + offsetForExpert = paddedTokensPerExpertPrefixSum[expert] + permutedIdx = offsetForExpert + offsetWithinExpert + + expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx + permutedIdxToExpandedIdx[permutedIdx] = expandedIdx + permutedIdxToTokenIdx[permutedIdx] = tokenIdx + return { + "paddedTokensPerExpertPrefixSum": paddedTokensPerExpertPrefixSum.to( + originalDevice + ), + "permutedBufferSize": permutedBufferSize.item(), + "expandedTokenIdxToPermutedIdx": expandedTokenIdxToPermutedIdx.to( + originalDevice + ), + "permutedIdxToExpandedIdx": permutedIdxToExpandedIdx.to(originalDevice), + "numTokensPerExpert": numTokensPerExpert.to(originalDevice), + "expandedTokenIdxToExpert": expandedTokenIdxToExpert.to(originalDevice), + "topKLogits": topKLogits.to(originalDevice), + "permutedIdxToTokenIdx": permutedIdxToTokenIdx.to(originalDevice), + "topKIndices": topKIndices.to(originalDevice), + } + + +def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor): + scores = F.sigmoid(logits) + scores_with_bias = scores + bias + if n_group > 1: + scores_shape = list(scores_with_bias.shape) + group_scores = torch.sum( + torch.topk( + scores_with_bias.view( + scores_shape[:-1] + [n_group, scores_shape[-1] // n_group] + ), + k=2, + dim=-1, + largest=True, + sorted=True, + )[0], + dim=-1, + ) + _, group_idx = torch.topk( + group_scores, k=topk_group, dim=-1, largest=True, sorted=True + ) + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(-1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]) + .reshape(scores_shape) + ) + scores_with_bias = scores_with_bias * score_mask + + _, topk_idx = torch.topk( + scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True + ) + new_mask = torch.zeros_like(scores) + new_mask.scatter_(-1, topk_idx, 1) + scores = scores * new_mask + score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 + scores = scores / score_sum * routed_scaling_factor + return scores + + +# Tiered TopK routing used by DeepSeek +def routing_reference_no_aux( + expert_logits, + routing_bias, + top_k, + n_groups, + top_k_groups, + routed_scaling, + padding, + use_routing_scales_on_input=False, +): + routing_logits = expert_logits.to(dtype=torch.float, device="cuda") + if use_routing_scales_on_input: + # if using routing scales on input, topK == 1 and the score is a plain sigmoid + scores = F.sigmoid(routing_logits) + else: + scores = noaux_tc_ref( + routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling + ) + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +# TopK -> Softmax +def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): + topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) + topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +# Softmax->TopK -> Normalize +def 
routing_reference_renormalize_naive(expert_logits, top_k, num_experts, padding): + norm_topk_prob = True + scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1) + topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1) + + if norm_topk_prob: # only diff with mixtral sparse moe block! + topk_values /= topk_values.sum(dim=-1, keepdim=True) + topk_values = topk_values.to(expert_logits.dtype) + scores = scores.to(expert_logits.dtype) + + new_mask = torch.zeros_like(expert_logits) + new_mask.scatter_(-1, topk_idx, 1) + scores = expert_logits * new_mask + + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + scores[i, topk_idx[i, j]] = topk_values[i, j] + permute_info = routing_reference(scores, top_k, padding) + return permute_info, scores + + +def run_moe_dequant(args, quant_mode=["fp4"]): + # Permute + total_num_padded_tokens = args.permute_info["permutedBufferSize"] + expanded_idx_to_permuted_idx = args.permute_info[ + "expandedTokenIdxToPermutedIdx" + ].cpu() + num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu() + permute_output = torch.full( + (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + for i in range(args.num_tokens): + for j in range(args.top_k): + permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j] + permute_output[permuted_idx] = args.hidden_states[i] + # Gemm1 + gemm1_output = torch.full( + (total_num_padded_tokens, 2 * args.intermediate_size), + float("nan"), + device="cuda", + ).to(torch.float) + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + if my_num_tokens == 0: + continue + my_a = permute_output[i : i + my_num_tokens] + my_b = args.gemm1_weights[expert_idx] + my_c = my_a @ my_b.t() + gemm1_output[i : i + my_num_tokens] = my_c + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + + if args.use_routing_scales_on_input: + assert args.top_k == 1 + # For each token and its top_k experts + for token_idx in range(args.num_tokens): + for k in range(args.top_k): + # Get the permuted index for this token's k-th expert + expanded_idx = token_idx * args.top_k + k + permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] + expert_weight = args.permute_info["topKLogits"].to(torch.float) + # Get the expert weight for this token and expert + weight = expert_weight[token_idx, k] + # Scale the corresponding row in gemm1_output + gemm1_output[permuted_idx] *= weight + + # Activation + activation_output = torch.full( + (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" + ).to(torch.float) + + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + if my_num_tokens == 0: + continue + my_a = gemm1_output[i : i + my_num_tokens] + my_x1 = my_a[:, : args.intermediate_size] + my_x2 = my_a[:, args.intermediate_size :] + activation_output[i : i + my_num_tokens] = F.silu(my_x2) * my_x1 + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + + if quant_mode == "fp4": + activation_output, c_global_sf = quant_dequant_fp4( + activation_output.to(torch.bfloat16), False, True + ) + activation_output = activation_output.to(torch.float) + args.c_global_sf = c_global_sf + + # Gemm2 + gemm2_output = torch.full( + (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + i = 0 + for expert_idx in range(args.num_experts): + my_num_tokens = num_tokens_per_expert[expert_idx] + 
if my_num_tokens == 0: + continue + my_a = activation_output[i : i + my_num_tokens] + my_b = args.gemm2_weights[expert_idx] + my_c = my_a @ my_b.t() + gemm2_output[i : i + my_num_tokens] = my_c + i += my_num_tokens + i = (i + args.padding - 1) // args.padding * args.padding + # Finalize + expert_weight = args.permute_info["topKLogits"].to(torch.float) + finalize_output = torch.full( + (args.num_tokens, args.hidden_size), float("nan"), device="cuda" + ).to(torch.float) + for i in range(args.num_tokens): + acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda") + for top_k_idx in range(args.top_k): + expanded_idx = i * args.top_k + top_k_idx + permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] + original_vector = gemm2_output[permuted_idx] + weight = ( + expert_weight[i, top_k_idx] + if not args.use_routing_scales_on_input + else 1.0 + ) + acc += original_vector * weight + finalize_output[i] = acc + return finalize_output + def e2m1_and_ufp8_scale_to_float_tensor_v2( e2m1_tensor: torch.Tensor, @@ -106,6 +392,63 @@ def e2m1_and_ufp8_scale_to_float_tensor_v2( return float_tensor +def e2m1_and_ufp8_scale_batches( + mat_fp4: torch.Tensor, + scale_tensor: torch.Tensor, + global_scale_tensor: torch.Tensor, + sf_vec_size: int, + ufp8_type: int = 1, +): + num_batches = mat_fp4.size(0) + + scale_tensor = scale_tensor.view(num_batches, -1) + + tensors = [ + e2m1_and_ufp8_scale_to_float_tensor_v2( + mat_fp4[b, :, :], scale_tensor[b, :], global_scale_tensor[b], sf_vec_size + ) + for b in range(num_batches) + ] + + result = torch.stack(tensors) + + return result + + +def run_moe_reference_fp4(args): + sf_vec_size = 16 + + hidden_states_dequant = e2m1_and_ufp8_scale_to_float_tensor_v2( + args.hidden_states, + args.hidden_states_scale, + 1 / args.hidden_states_scale_global, + sf_vec_size, + ).cuda() + + gemm1_weights_dequant = e2m1_and_ufp8_scale_batches( + args.gemm1_weights, args.gemm1_scales, 1 / args.gemm1_scales_global, sf_vec_size + ).cuda() + + gemm2_weights_dequant = e2m1_and_ufp8_scale_batches( + args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global, sf_vec_size + ).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + ) + + return run_moe_dequant(args_dequant, "fp4"), args_dequant def quant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): @@ -167,6 +510,169 @@ def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): return a_pt.cuda(), a_global_sf +def check_accuracy(a, b, atol, rtol, percent): + if torch.any(torch.isnan(a)): + raise Exception("NaN in a") + if torch.any(torch.isnan(b)): + raise Exception("NaN in b") + if torch.any(torch.isinf(a)): + raise Exception("Inf in a") + if torch.any(torch.isinf(b)): + raise Exception("Inf in b") + assert a.shape == b.shape + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol) + ) + + +def create_expert_logits(num_token, num_experts, k): + """ + Create deterministic expert logits for testing where specific experts + are guaranteed to be selected for each token. 
+ + Args: + num_token: Number of tokens + num_experts: Number of experts + k: Top-k value (number of experts to select per token) + + Returns: + logits: Expert logits tensor [num_token, num_experts] (CUDA bfloat16) + index: Expected top-k indices [num_token, k] (CUDA) + large_random: The large random values used [num_token, k] (CUDA) + """ + # 1. Create logits tensor + logits = torch.zeros(num_token, num_experts) + + # 2. Set index sequence + final_size = num_token * k + repeat_count = math.ceil(final_size / num_experts) + indices = torch.arange(num_experts, dtype=torch.int32) + indices = indices.repeat(repeat_count) + indices = indices[:final_size] + index = indices.view(num_token, k).contiguous() + + # 3. Generate large random numbers + large_random = torch.randint(5, 11, (num_token, k), dtype=torch.float32) + + # 4. Put the random number to the place we want + for token_id in range(num_token): + for j in range(k): + expert_idx = index[token_id, j] + logits[token_id, expert_idx] = large_random[token_id, j] + + # 5. Set smaller random numbers in other places + mask = logits == 0 + logits[mask] = torch.rand(mask.sum()) + + logits = torch.nn.functional.softmax(logits, dim=-1) + + # Convert to CUDA tensors with appropriate dtypes + logits = logits.to(device="cuda", dtype=torch.bfloat16) + index = index.to(device="cuda") + large_random = large_random.to(device="cuda") + + return logits, index, large_random + + +def compute_moe_reference_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, +): + """ + Compute the reference MoE output using dequantized operations with full routing support. 
+ + Returns: + output_dequant_reference: Reference output tensor + args_dequant: Dequantized arguments for debugging + """ + use_ue8m0 = False + + # Quantize hidden states + ( + hidden_states_fp4_bytes, + hidden_states_scale_fp4_bytes, + hidden_states_scale_global, + ) = quant_fp4(hidden_states, use_ue8m0, True) + + # Quantize the weights for FC1 + gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( + quant_fp4_batches(gemm1_weights, num_experts, use_ue8m0, True) + ) + + # Quantize the weights for FC2 + gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( + quant_fp4_batches(gemm2_weights, num_experts, use_ue8m0, True) + ) + + # Generate routing info based on method + if routing_method_type == RoutingMethodType.DeepSeekV3: + permute_info, scores = routing_reference_no_aux( + expert_logits, + routing_bias, + top_k, + n_groups, + top_k_groups, + routed_scaling, + padding, + ) + elif routing_method_type == RoutingMethodType.Renormalize: + permute_info, scores = routing_reference_renormalize( + expert_logits, top_k, num_experts, padding + ) + elif routing_method_type == RoutingMethodType.RenormalizeNaive: + permute_info, scores = routing_reference_renormalize_naive( + expert_logits, top_k, num_experts, padding + ) + else: + raise NotImplementedError( + f"Routing method {routing_method_type} not implemented" + ) + + # Create arguments for reference computation + args = moe_args( + num_tokens, + num_experts, + hidden_size, + intermediate_size, + top_k, + padding, + hidden_states_fp4_bytes, + hidden_states_scale_fp4_bytes, + hidden_states_scale_global, + scores, + gemm1_weights_fp4_bytes, + gemm1_scales_fp4_bytes, + gemm1_scales_global, + gemm2_weights_fp4_bytes, + gemm2_scales_fp4_bytes, + gemm2_scales_global, + permute_info, + False, + ) + + # Run the reference implementation + output_dequant_reference, args_dequant = run_moe_reference_fp4(args) + + return output_dequant_reference, args_dequant, args + def compute_moe_actual_with_routing( num_tokens, @@ -368,7 +874,94 @@ def prepare_static_weights(): return output_dequant_actual -def run_one( +def compare_moe_outputs( + output_dequant_reference, + output_dequant_actual, + seed, + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + routing_method_type, +): + """ + Compare reference and actual MoE outputs and perform accuracy analysis. 
+ + Raises: + Exception: If accuracy test fails + """ + # Use check_accuracy to validate - it will raise exception if test fails + check_accuracy( + output_dequant_reference, + output_dequant_actual, + atol=0.1, + rtol=0.85, + percent=0.925, + ) + + +@pytest.mark.parametrize("num_tokens", [1, 1024, 4096]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192]) +@pytest.mark.parametrize( + "routing_info", + [ + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + id="RoutingDSv3", + ), + pytest.param( + { + "num_experts": 72, + "top_k": 6, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + id="RoutingDSlite", + ), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + }, + id="RoutingRenormalize", + ), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.RenormalizeNaive, + }, + id="RoutingRenormalizeNaive", + ), + ], +) +def test_moe_nvfp4( num_tokens, hidden_size, intermediate_size, @@ -430,6 +1023,25 @@ def run_one( dtype=torch.bfloat16, ) + # Compute reference output with updated routing method handling + output_dequant_reference, args_dequant, args = compute_moe_reference_with_routing( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + expert_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + top_k, + padding, + n_groups, + top_k_groups, + routed_scaling, + routing_method_type, + ) + # Compute actual output using optimized kernel output_dequant_actual = compute_moe_actual_with_routing( num_tokens, @@ -452,12 +1064,15 @@ def run_one( args, ) -if __name__ == '__main__': - for config in test_configs: - for batch_size in BATCH_SIZES: - run_one( - num_tokens=batch_size, - hidden_size=config["hidden_size"], - intermediate_size=config["intermediate_size"], - routing_info=config["routing_info"], - ) + # Compare outputs - will raise exception if test fails + compare_moe_outputs( + output_dequant_reference, + output_dequant_actual, + seed, + num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + routing_method_type, + ) \ No newline at end of file From 3baee7e273cffc225117189b27a48f00afb54f6b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:11:26 +0800 Subject: [PATCH 26/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py | 129 ++++++++++++----------- 1 file changed, 68 insertions(+), 61 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py b/benchmarks/bench_trtllm_gen_fused_moe.py index d25ca5b7f..9a6c7abdd 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py @@ -901,66 +901,6 @@ def compare_moe_outputs( ) -@pytest.mark.parametrize("num_tokens", [1, 1024, 4096]) -@pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192]) -@pytest.mark.parametrize( - "routing_info", - [ - pytest.param( - { - "num_experts": 256, - "top_k": 8, - 
"padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - }, - id="RoutingDSv3", - ), - pytest.param( - { - "num_experts": 72, - "top_k": 6, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - }, - id="RoutingDSlite", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - }, - id="RoutingRenormalize", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.RenormalizeNaive, - }, - id="RoutingRenormalizeNaive", - ), - ], -) def test_moe_nvfp4( num_tokens, hidden_size, @@ -1075,4 +1015,71 @@ def test_moe_nvfp4( num_experts, top_k, routing_method_type, - ) \ No newline at end of file + ) + +# --------------------------------------------------------------------------- + +BATCH_SIZES = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, +] + +test_configs = [ + # NOTE MODIFIED ADD + *[ + { + "hidden_size": 7168, + "intermediate_size": 2048, + # RoutingDSv3 + "routing_info": { + # TODO correct? + "num_experts": num_experts, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + }, + } + for num_experts in [ + 288 // 1, + 288 // 2, + 288 // 4, + 288 // 8, + 288 // 16, + 288 // 32, + 288 // 48, + 288 // 72, + ] + ], +] + +if __name__ == '__main__': + for config in test_configs: + for batch_size in BATCH_SIZES: + test_moe_nvfp4( + num_tokens=batch_size, + hidden_size=config["hidden_size"], + intermediate_size=config["intermediate_size"], + routing_info=config["routing_info"], + ) From 57a18b6dd35afea49aa4c3e10eb236e9a6e606df Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:13:23 +0800 Subject: [PATCH 27/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py | 47 +++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py b/benchmarks/bench_trtllm_gen_fused_moe.py index 9a6c7abdd..96aa17f3f 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - +import json import math +import os +import time import pytest import torch @@ -29,6 +31,7 @@ shuffle_matrix_sf_a, ) from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +from flashinfer.testing.utils import bench_kineto class moe_args: @@ -844,6 +847,48 @@ def prepare_static_weights(): torch.float8_e4m3fn ).reshape(-1) + if 1: + trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR") + [time_gemm1, time_gemm2] = bench_kineto( + lambda: trtllm_fp4_block_scale_moe( + expert_logits, + routing_bias, + hidden_states_fp4, + hidden_states_scale_linear_fp4, + static_data["gemm1_weights_fp4_shuffled"], + static_data["gemm1_scales_fp4_shuffled"], + static_data["gemm2_weights_fp4_shuffled"], + static_data["gemm2_scales_fp4_shuffled"], + static_data["scale_c_fc1"], + static_data["scale_gate_fc1"], + static_data["scale_c_fc2"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + tile_tokens_dim, + routing_method_type, + do_finalize=True, + ), + kernel_names="TODO_what_name", + num_kernels_per_period=2, + trace_path=f"{trace_dir}/{time.time()}.trace.json.gz" if trace_dir else None, + ) + + # NOTE MODIFIED + print(f"MAIN_OUTPUT=" + json.dumps(dict( + batch_size=batch_size, + num_experts=num_experts, + top_k=top_k, + intermediate_size=intermediate_size, + time_gemm1_us=time_gemm1 * 1e6, + time_gemm2_us=time_gemm2 * 1e6, + ))) + output = trtllm_fp4_block_scale_moe( expert_logits, routing_bias, From 806cf11323e01404b0db7b42541f3c8bd9f10aac Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 24 Jul 2025 14:14:13 +0800 Subject: [PATCH 28/28] more --- benchmarks/bench_trtllm_gen_fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe.py b/benchmarks/bench_trtllm_gen_fused_moe.py index 96aa17f3f..e163f22b1 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe.py +++ b/benchmarks/bench_trtllm_gen_fused_moe.py @@ -1107,7 +1107,7 @@ def test_moe_nvfp4( }, } for num_experts in [ - 288 // 1, + # 288 // 1, # not supported 288 // 2, 288 // 4, 288 // 8,