diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 8bc2b9bff3d5..15d470cc15df 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -1,3 +1,5 @@
+#include "cuda_utils.h"
+
 #include
 #include
 #include
@@ -119,3 +121,59 @@ void silu_and_mul_quant(torch::Tensor& out,  // [..., d]
   TORCH_CHECK(input.size(-1) % 2 == 0);
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
+
+void silu_mul_fp8_quant_deep_gemm_cuda(
+    const at::Tensor& input,   // (E, T, 2*H)
+    const at::Tensor& counts,  // (E)
+    at::Tensor& y_q,           // (E, T, H) [OUT]
+    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
+    int64_t group_size, double eps, double fp8_min, double fp8_max,
+    bool use_ue8m0) {
+  static constexpr int NUM_WARPS = 4;
+
+  using Idx_t = uint32_t;
+
+  Idx_t E = input.size(0);
+  Idx_t T = input.size(1);
+  Idx_t H = input.size(2) / 2;
+  Idx_t G = cuda_utils::ceil_div(H, Idx_t(group_size * NUM_WARPS));
+  Idx_t stride_i_e = input.stride(0);
+  Idx_t stride_i_t = input.stride(1);
+  Idx_t stride_i_h = input.stride(2);
+  Idx_t stride_yq_e = y_q.stride(0);
+  Idx_t stride_yq_t = y_q.stride(1);
+  Idx_t stride_yq_h = y_q.stride(2);
+  Idx_t stride_ys_e = y_s.stride(0);
+  Idx_t stride_ys_t = y_s.stride(1);
+  Idx_t stride_ys_g = y_s.stride(2);
+
+  int stride_counts_e = counts.stride(0);
+
+  static constexpr int NUM_PARALLEL_TOKENS = 16;
+  dim3 grid(E * G, NUM_PARALLEL_TOKENS);
+  dim3 block(NUM_WARPS * 32);
+
+  if (use_ue8m0) {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, true>
+        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<int32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  } else {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, false>
+        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<int32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  }
+}
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index 5a0379dfb447..53a576d7ce7e 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -10,11 +10,13 @@
 # (E, T, H, group_size, seed)
 CASES = [
-    (1, 1, 128, 64, 0),
-    (1, 4, 128, 128, 0),
-    (2, 4, 256, 128, 0),
-    (32, 64, 256, 128, 0),
-    (17, 31, 768, 128, 0),
+    (8, 16, 7168, 128, 0),
+    (8, 32, 7168, 128, 0),
+    (8, 64, 7168, 128, 0),
+    (8, 128, 7168, 128, 0),
+    (8, 256, 7168, 128, 0),
+    (8, 512, 7168, 128, 0),
+    (8, 1024, 7168, 128, 0),
 ]
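
For reference, the math the new CUDA path implements can be sketched in plain PyTorch: groupwise FP8 (e4m3) quantization of silu(y[..., :H]) * y[..., H:], where only the first tokens_per_expert[e] tokens of each expert are meaningful. This is a rough reference only, not code from the diff: the helper name ref_silu_mul_fp8_quant, the zeroing of padding rows, and the exact eps handling are illustrative assumptions, and the UE8M0 (power-of-two scale) variant selected by use_ue8m0 is omitted.

import torch
import torch.nn.functional as F


def ref_silu_mul_fp8_quant(
    y: torch.Tensor,                  # (E, T, 2*H)
    tokens_per_expert: torch.Tensor,  # (E,) valid tokens per expert
    group_size: int = 128,
    eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    E, T, H2 = y.shape
    H = H2 // 2
    G = H // group_size
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max

    # silu(y[..., :H]) * y[..., H:], computed in fp32 for the reference
    gate, up = y[..., :H].float(), y[..., H:].float()
    act = F.silu(gate) * up
    grouped = act.reshape(E, T, G, group_size)

    # per-(token, group) scale = max(|x|, eps) / fp8_max
    amax = grouped.abs().amax(dim=-1).clamp(min=eps)
    scale = amax / fp8_max
    q = (grouped / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)

    # mask out padding tokens (assumption: zero data, unit scale)
    valid = (torch.arange(T, device=y.device)[None, :]
             < tokens_per_expert.to(y.device)[:, None])  # (E, T)
    q = q * valid[:, :, None, None]
    scale = torch.where(valid[:, :, None], scale, torch.ones_like(scale))

    y_q = q.reshape(E, T, H).to(fp8)  # (E, T, H), float8_e4m3fn
    y_s = scale                       # (E, T, H // group_size), float32
    return y_q, y_s

One way to check the CUDA path against this sketch (or against the Triton implementation retained below as silu_mul_fp8_quant_deep_gemm_old) is to dequantize y_q.float() with the per-group scales and compare via torch.testing.assert_close.
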
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index a5326dfe84f6..440a7375e536 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -106,6 +106,53 @@ def silu_mul_fp8_quant_deep_gemm(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     group_size: int = 128,
     eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    G = H // group_size
+    assert H % group_size == 0, "H must be divisible by group_size"
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
+
+    tokens_per_expert = tokens_per_expert.to(device=y.device,
+                                             dtype=torch.int32)
+
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # per-group scales laid out with the token dimension innermost (stride 1)
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided((E, T, G),
+                              (stride_ys_e, stride_ys_t, stride_ys_g),
+                              dtype=torch.float32,
+                              device=y.device)
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+    use_ue8m0 = is_deep_gemm_e8m0_used()
+    torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
+        y,
+        tokens_per_expert,
+        y_q,
+        y_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        use_ue8m0,
+    )
+
+    return y_q, y_s
+
+
+def silu_mul_fp8_quant_deep_gemm_old(
+    y: torch.Tensor,  # (E, T, 2*H)
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    group_size: int = 128,
+    eps: float = 1e-10,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales