From 1aaac7f58053355d1364a09d8cd418a01a93b81d Mon Sep 17 00:00:00 2001
From: elvircrn
Date: Mon, 1 Sep 2025 16:30:31 +0000
Subject: [PATCH 1/2] updated

Signed-off-by: Robert Shaw
---
 csrc/quantization/activation_kernels.cu       | 58 +++++++++++++++++++
 .../moe/test_silu_mul_fp8_quant_deep_gemm.py  | 12 ++--
 .../layers/fused_moe/batched_deep_gemm_moe.py | 47 +++++++++++++++
 3 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 8bc2b9bff3d5..15d470cc15df 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -1,3 +1,5 @@
+#include "cuda_utils.h"
+
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -119,3 +121,59 @@ void silu_and_mul_quant(torch::Tensor& out,  // [..., d]
   TORCH_CHECK(input.size(-1) % 2 == 0);
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
+
+void silu_mul_fp8_quant_deep_gemm_cuda(
+    const at::Tensor& input,   // (E, T, 2*H)
+    const at::Tensor& counts,  // (E)
+    at::Tensor& y_q,           // (E, T, H) [OUT]
+    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
+    int64_t group_size, double eps, double fp8_min, double fp8_max,
+    bool use_ue8m0) {
+  static constexpr int NUM_WARPS = 4;
+
+  using Idx_t = uint32_t;
+
+  Idx_t E = input.size(0);
+  Idx_t T = input.size(1);
+  Idx_t H = input.size(2) / 2;
+  Idx_t G = cuda_utils::ceil_div(H, Idx_t(group_size * NUM_WARPS));
+  Idx_t stride_i_e = input.stride(0);
+  Idx_t stride_i_t = input.stride(1);
+  Idx_t stride_i_h = input.stride(2);
+  Idx_t stride_yq_e = y_q.stride(0);
+  Idx_t stride_yq_t = y_q.stride(1);
+  Idx_t stride_yq_h = y_q.stride(2);
+  Idx_t stride_ys_e = y_s.stride(0);
+  Idx_t stride_ys_t = y_s.stride(1);
+  Idx_t stride_ys_g = y_s.stride(2);
+
+  int stride_counts_e = counts.stride(0);
+
+  static constexpr int NUM_PARALLEL_TOKENS = 16;
+  dim3 grid(E * G, NUM_PARALLEL_TOKENS);
+  dim3 block(NUM_WARPS * 32);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (use_ue8m0) {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, true>
+        <<<grid, block, 0, stream>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<int32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  } else {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, false>
+        <<<grid, block, 0, stream>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<int32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  }
+}
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index 5a0379dfb447..53a576d7ce7e 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -10,11 +10,13 @@

 # (E, T, H, group_size, seed)
 CASES = [
-    (1, 1, 128, 64, 0),
-    (1, 4, 128, 128, 0),
-    (2, 4, 256, 128, 0),
-    (32, 64, 256, 128, 0),
-    (17, 31, 768, 128, 0),
+    (8, 16, 7168, 128, 0),
+    (8, 32, 7168, 128, 0),
+    (8, 64, 7168, 128, 0),
+    (8, 128, 7168, 128, 0),
+    (8, 256, 7168, 128, 0),
+    (8, 512, 7168, 128, 0),
+    (8, 1024, 7168, 128, 0),
 ]

diff --git
a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index a5326dfe84f6..1203c5fc43ea 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -101,6 +101,53 @@ def _silu_mul_fp8_quant_deep_gemm(
     tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


+def silu_mul_fp8_quant_deep_gemm_cuda(
+    y: torch.Tensor,  # (E, T, 2*H)
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    group_size: int = 128,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    assert H % group_size == 0, "H must be divisible by group_size"
+    G = H // group_size
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
+
+    tokens_per_expert = tokens_per_expert.to(device=y.device,
+                                             dtype=torch.int32)
+
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # y_s is logically (E, T, G) but stored as (E, G, T) for DeepGEMM.
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided((E, T, G),
+                              (stride_ys_e, stride_ys_t, stride_ys_g),
+                              dtype=torch.float32,
+                              device=y.device)
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+    use_ue8m0 = is_deep_gemm_e8m0_used()
+    torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
+        y,
+        tokens_per_expert,
+        y_q,
+        y_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        use_ue8m0,
+    )
+
+    return y_q, y_s
+
+
 def silu_mul_fp8_quant_deep_gemm(
     y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert

From 6958b23fd16e53f1dc7449a5814ac1c8d9aa9094 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Mon, 1 Sep 2025 17:10:11 +0000
Subject: [PATCH 2/2] tweak name

Signed-off-by: Robert Shaw
---
 vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 1203c5fc43ea..440a7375e536 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -101,7 +101,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


-def silu_mul_fp8_quant_deep_gemm_cuda(
+def silu_mul_fp8_quant_deep_gemm(
     y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     group_size: int = 128,
@@ -148,7 +148,7 @@ def silu_mul_fp8_quant_deep_gemm(
     return y_q, y_s


-def silu_mul_fp8_quant_deep_gemm(
+def silu_mul_fp8_quant_deep_gemm_old(
     y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     group_size: int = 128,
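
After both patches, the CUDA path sits behind the old Triton entry point's name, with the Triton version kept as silu_mul_fp8_quant_deep_gemm_old. A minimal sketch of exercising one against the other is below; the function names, module path, and shapes come from the diffs above, while the dequantize-and-compare logic and tolerances are assumptions for illustration, not part of this series.

    import torch
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        silu_mul_fp8_quant_deep_gemm,      # new CUDA-backed wrapper
        silu_mul_fp8_quant_deep_gemm_old,  # retained Triton reference
    )

    # Smallest of the new test cases: (E, T, H, group_size) = (8, 16, 7168, 128).
    E, T, H, group_size = 8, 16, 7168, 128
    y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.bfloat16)
    counts = torch.randint(0, T + 1, (E,), device="cuda", dtype=torch.int32)

    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, counts, group_size=group_size)
    y_q_ref, y_s_ref = silu_mul_fp8_quant_deep_gemm_old(y, counts,
                                                        group_size=group_size)

    # Compare dequantized outputs, but only on each expert's valid tokens.
    for e in range(E):
        n = int(counts[e])
        deq = y_q[e, :n].float() * y_s[e, :n].repeat_interleave(group_size,
                                                                dim=-1)
        ref = y_q_ref[e, :n].float() * y_s_ref[e, :n].repeat_interleave(
            group_size, dim=-1)
        torch.testing.assert_close(deq, ref, atol=0.1, rtol=0.1)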