diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index 0650cbf3cc18..e8670ced428b 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -6,28 +6,31 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm, + silu_mul_fp8_quant_deep_gemm_cuda, +) +from vllm.model_executor.layers.fused_moe.old_batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm as gold, ) from vllm.platforms import current_platform -def benchmark(E, T, H, G=128, runs=50): +def benchmark(k, E, T, H, G=128, runs=100): current_platform.seed_everything(42) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") + y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() tokens_per_expert = torch.randint( T // 2, T, size=(E,), dtype=torch.int32, device="cuda" ) # Warmup - for _ in range(10): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + for _ in range(20): + k(y, tokens_per_expert, group_size=G) torch.cuda.synchronize() # Benchmark torch.cuda.synchronize() start = time.perf_counter() for _ in range(runs): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) + k(y, tokens_per_expert, group_size=G) torch.cuda.synchronize() avg_time = (time.perf_counter() - start) / runs * 1000 @@ -52,26 +55,46 @@ def benchmark(E, T, H, G=128, runs=50): configs = [ - (8, 32, 1024), - (16, 64, 2048), - (32, 128, 4096), # DeepSeekV3 Configs - (256, 16, 7168), - (256, 32, 7168), - (256, 64, 7168), - (256, 128, 7168), - (256, 256, 7168), - (256, 512, 7168), - (256, 1024, 7168), + (8, 16, 7168), + (8, 32, 7168), + (8, 64, 7168), + (8, 128, 7168), + (8, 256, 7168), + (8, 512, 7168), + (8, 1024, 7168), + (9, 16, 7168), + (9, 32, 7168), + (9, 64, 7168), + (9, 128, 7168), + (9, 256, 7168), + (9, 512, 7168), + (9, 1024, 7168), + # (16, 64, 2048), + # (32, 128, 4096), + # (256, 16, 7168), + # (256, 32, 7168), + # (256, 64, 7168), + # (256, 128, 7168), + # (256, 256, 7168), + # (256, 512, 7168), + # (256, 1024, 7168), ] -print(f"GPU: {torch.cuda.get_device_name()}") + +print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel") +print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") +print("-" * 50) + +for E, T, H in configs: + time_ms, gflops, gbps = benchmark(silu_mul_fp8_quant_deep_gemm_cuda, E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") + + +print(f"GPU: {torch.cuda.get_device_name()} Baseline") print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") print("-" * 50) for E, T, H in configs: - try: - time_ms, gflops, gbps = benchmark(E, T, H) - print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") - except Exception: - print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") + time_ms, gflops, gbps = benchmark(gold, E, T, H) + print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") diff --git a/csrc/ops.h b/csrc/ops.h index 7a176a5c0032..3188c9d7d1ea 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -137,6 +137,13 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& input_global_scale); #endif +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, double eps, double fp8_min, 
double fp8_max, + bool use_ue8m0); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); @@ -354,4 +361,4 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8bc2b9bff3d5..6c02aff3b0be 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -1,3 +1,5 @@ +#include "cuda_utils.h" + #include #include #include @@ -9,6 +11,10 @@ #include "quantization/fp8/common.cuh" +#include + +#include "core/registration.h" + namespace vllm { template @@ -87,6 +93,402 @@ __global__ void act_and_mul_quant_kernel( } } } + +__device__ __forceinline__ float silu(float x) { + return x * (1.f / (1.f + expf(-x))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +__device__ __forceinline__ float warp_max(float v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +template +__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) { + auto smem_ptr = reinterpret_cast(_smem_ptr); + auto glob_ptr = reinterpret_cast(_glob_ptr); + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ __forceinline__ void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ __forceinline__ void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +} + +template <> +__device__ __forceinline__ void cp_async_wait<0>() { + asm volatile("cp.async.wait_all;\n" ::); +} + +__device__ __forceinline__ float clip(float v, float mmin, float mmax) { + return fminf(mmax, fmaxf(v, mmin)); +} + +__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v, + __nv_bfloat16 mmin, + __nv_bfloat16 mmax) { + return __hmin(mmax, __hmax(v, mmin)); +} + +__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v, + __nv_bfloat162 mmin, + __nv_bfloat162 mmax) { + return __hmin2(mmax, __hmax2(v, mmin)); +} + +template +__global__ void silu_mul_fp8_quant_deep_gemm_kernel( + const scalar_t* __restrict__ _input, __nv_fp8_e4m3* __restrict__ _y_q, + float* __restrict__ _y_s, const uint32_t* __restrict__ counts, + + // sizes + Idx_t H, Idx_t G, + + // strides (in elements) + Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, + Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, + Idx_t stride_ys_g, Idx_t stride_counts_e, + + // quant params + float fp8_min, float fp8_max) { + static constexpr float EPS = 1e-10; + static constexpr uint32_t S_NUM_128 = + 2 * (GROUP_SIZE / 8) * NUM_WARPS * NUM_STAGES; + static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; + static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; + static constexpr uint32_t S_NUM_64 = S_NUM_128 * 2; + 
static constexpr uint32_t S_NUM_32 = S_NUM_64 * 2; + __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + __shared__ Idx_t s_counts[1]; + + const Idx_t tid = threadIdx.x; + const Idx_t warp_id = tid / WARP_SIZE; + const Idx_t lane_id = tid % WARP_SIZE; + + auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + + // block handles one (expert e, group g) + Idx_t pid = blockIdx.x; + Idx_t e = pid / G; + Idx_t g = pid % G; + + if (!tid) { + s_counts[0] = counts[e * stride_counts_e]; + } + + const Idx_t stride_i_t_128 = stride_i_t / 8u; + + __syncthreads(); + const Idx_t n_tokens = s_counts[0]; + + auto par_id = blockIdx.y; + auto chunk_size = (n_tokens + NUM_PARALLEL_TOKENS - 1) / NUM_PARALLEL_TOKENS; + auto n_tokens_lower = par_id * chunk_size; + auto n_tokens_upper = min(n_tokens, (par_id + 1) * chunk_size); + + // base offsets (element-based) + const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; + const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; + const Idx_t base_yq = + e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; + + Idx_t gate_off_128 = (base_i / 8u); + auto input_128_ptr = reinterpret_cast(_input); + auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % 64) + + stride_i_t_128 * n_tokens_lower; + auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; + + auto y_s_ptr = _y_s + base_ys + warp_id + stride_ys_t * n_tokens_lower; + + auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + 2 * lane_id + + stride_yq_t * n_tokens_lower; + + Idx_t t_load = n_tokens_lower, load_stage_id = 0; + auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); + auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; + auto load_and_advance_y_pred = [&] { + if (t_load < n_tokens_upper) { + auto stage_offset = + (load_stage_id % NUM_STAGES) * (NUM_WARPS * WARP_SIZE / 2); + + auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; + auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + + if (tid < HALF_THREAD_COUNT) { + cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); + gate_128_ptr += stride_i_t_128; + } else { + cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); + up_128_ptr += stride_i_t_128; + } + ++t_load; + ++load_stage_id; + } + cp_async_fence(); + }; + +#pragma unroll + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_and_advance_y_pred(); + } + + auto s_gate_ptr = s_buff_compute_32 + warp_id * (GROUP_SIZE / 2) + lane_id; + auto s_up_ptr = s_gate_ptr + S_NUM_32 / 2; + + Idx_t stage_id{}; + for (Idx_t t = n_tokens_lower; t < n_tokens_upper; ++t) { + float y_max = EPS; + float results[4]; + + cp_async_wait(); + __syncthreads(); + + load_and_advance_y_pred(); + + const auto compute_pipeline_offset = + ((stage_id++) % NUM_STAGES) * (GROUP_SIZE / 2u) * NUM_WARPS; + + auto s_gate_compute = s_gate_ptr + compute_pipeline_offset; + auto s_up_compute = s_up_ptr + compute_pipeline_offset; + +#pragma unroll + for (int i = 0; i < 2; i++) { + float2 gate = silu2(__bfloat1622float2(*s_gate_compute)); + float2 upv = __bfloat1622float2(*s_up_compute); + + results[2 * i] = gate.x * upv.x; + results[2 * i + 1] = gate.y * upv.y; + + y_max = + fmaxf(y_max, fmaxf(fabsf(results[2 * i]), fabsf(results[2 * i + 1]))); + s_gate_compute += WARP_SIZE; + s_up_compute += WARP_SIZE; + } + + float y_s = warp_max(y_max) / fp8_max; + + if constexpr (USE_UE8M0) { + y_s = exp2f(ceilf(log2f(y_s))); + } + +#pragma unroll + for (Idx_t i = 0; i < 4; ++i) { + results[i] = clip(results[i] / y_s, 
fp8_min, fp8_max); + } + + auto local_y_q_ptr = reinterpret_cast(y_q_ptr); + const auto r4 = reinterpret_cast(results); + auto fp8x4 = __nv_fp8x4_e4m3(*r4); + auto resultfp8x2 = reinterpret_cast<__nv_fp8x2_e4m3*>(&fp8x4.__x); + +#pragma unroll + for (Idx_t i = 0; i < 2; ++i) { + auto res_u8 = reinterpret_cast(&resultfp8x2[i].__x); + local_y_q_ptr[0] = res_u8[0]; + local_y_q_ptr[1] = res_u8[1]; + local_y_q_ptr += 2 * WARP_SIZE; + } + + y_q_ptr += stride_yq_t; + + if (lane_id == 0) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_t; + } + } +} + +template +__global__ void __silu_mul_fp8_quant_deep_gemm_kernel( + const scalar_t* __restrict__ _input, __nv_fp8_e4m3* __restrict__ _y_q, + float* __restrict__ _y_s, const uint32_t* __restrict__ counts, + + // sizes + Idx_t H, Idx_t G, + + // strides (in elements) + Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, + Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, + Idx_t stride_ys_g, Idx_t stride_counts_e, + + // quant params + float fp8_min, float fp8_max) { + static constexpr float EPS = 1e-10; + static constexpr uint32_t S_NUM_128 = + 2 * (GROUP_SIZE / 8) * NUM_WARPS * NUM_STAGES; + static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; + static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; + static constexpr uint32_t S_NUM_64 = S_NUM_128 * 2; + static constexpr uint32_t S_NUM_32 = S_NUM_64 * 2; + __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + __shared__ Idx_t s_counts[1]; + + const Idx_t tid = threadIdx.x; + const Idx_t warp_id = tid / WARP_SIZE; + const Idx_t lane_id = tid % WARP_SIZE; + + auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + + // block handles one (expert e, group g) + Idx_t pid = blockIdx.x; + Idx_t e = pid / G; + Idx_t g = pid % G; + + if (!tid) { + s_counts[0] = counts[e * stride_counts_e]; + } + + const Idx_t stride_i_t_128 = stride_i_t / 8u; + + __syncthreads(); + const Idx_t n_tokens = s_counts[0]; + + auto par_id = blockIdx.y; + auto chunk_size = (n_tokens + NUM_PARALLEL_TOKENS - 1) / NUM_PARALLEL_TOKENS; + auto n_tokens_lower = par_id * chunk_size; + auto n_tokens_upper = min(n_tokens, (par_id + 1) * chunk_size); + + // base offsets (element-based) + const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; + const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; + const Idx_t base_yq = + e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; + + Idx_t gate_off_128 = (base_i / 8u); + auto input_128_ptr = reinterpret_cast(_input); + auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % 64) + + stride_i_t_128 * n_tokens_lower; + auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; + + auto y_s_ptr = _y_s + base_ys + warp_id + stride_ys_t * n_tokens_lower; + + auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + 2 * lane_id + + stride_yq_t * n_tokens_lower; + + Idx_t t_load = n_tokens_lower, load_stage_id = 0; + auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); + auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; + auto load_and_advance_y_pred = [&] { + if (t_load < n_tokens_upper) { + auto stage_offset = + (load_stage_id % NUM_STAGES) * (NUM_WARPS * WARP_SIZE / 2); + + auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; + auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + + if (tid < HALF_THREAD_COUNT) { + cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); + gate_128_ptr += stride_i_t_128; + } else { + 
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); + up_128_ptr += stride_i_t_128; + } + ++t_load; + ++load_stage_id; + } + cp_async_fence(); + }; + +#pragma unroll + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_and_advance_y_pred(); + } + + auto s_gate_ptr = s_buff_compute_32 + warp_id * (GROUP_SIZE / 2) + lane_id; + auto s_up_ptr = s_gate_ptr + S_NUM_32 / 2; + using bf16x2 = __nv_bfloat162; + using bf16 = __nv_bfloat16; + + Idx_t stage_id{}; + for (Idx_t t = n_tokens_lower; t < n_tokens_upper; ++t) { + bf16 y_max = EPS; + bf16x2 results[2]; + + cp_async_wait(); + __syncthreads(); + + load_and_advance_y_pred(); + + const auto compute_pipeline_offset = + ((stage_id++) % NUM_STAGES) * (GROUP_SIZE / 2u) * NUM_WARPS; + + auto s_gate_compute = s_gate_ptr + compute_pipeline_offset; + auto s_up_compute = s_up_ptr + compute_pipeline_offset; + +#pragma unroll + for (int i = 0; i < 2; i++) { + bf16x2 gate = __float22bfloat162_rn(__bfloat1622float2(*s_gate_compute)); + bf16x2 upv = *s_up_compute; + + results[i] = __hmul2(gate, upv); + + s_gate_compute += WARP_SIZE; + s_up_compute += WARP_SIZE; + } + + auto hmx2 = __hmax2(results[0], results[1]); + y_max = __hmax(hmx2.x, hmx2.y); + + auto y_s = warp_max(y_max) / bf16(fp8_max); + + const auto fp8min2 = make_bfloat162(fp8_min, fp8_min); + const auto fp8max2 = make_bfloat162(fp8_max, fp8_max); + + auto y_s2 = make_bfloat162(y_s, y_s); +#pragma unroll + for (Idx_t i = 0; i < 2; ++i) { + results[i] /= y_s2; + results[i] = clip(results[i], fp8min2, fp8max2); + } + + auto local_y_q_ptr = reinterpret_cast(y_q_ptr); +#pragma unroll + for (Idx_t i = 0; i < 2; ++i) { + auto resultfp8x2 = __nv_fp8x2_e4m3(results[i]); + auto res_u8 = reinterpret_cast(&resultfp8x2.__x); + local_y_q_ptr[0] = res_u8[0]; + local_y_q_ptr[1] = res_u8[1]; + local_y_q_ptr += 2 * WARP_SIZE; + } + + y_q_ptr += stride_yq_t; + + if (lane_id == 0) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_t; + } + } +} + } // namespace vllm // Launch activation, gating, and quantize kernel. 
@@ -119,3 +521,59 @@ void silu_and_mul_quant(torch::Tensor& out,  // [..., d]
   TORCH_CHECK(input.size(-1) % 2 == 0);
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
+
+void silu_mul_fp8_quant_deep_gemm_cuda(
+    const at::Tensor& input,   // (E, T, 2*H)
+    const at::Tensor& counts,  // (E)
+    at::Tensor& y_q,           // (E, T, H) [OUT]
+    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
+    int64_t group_size, double eps, double fp8_min, double fp8_max,
+    bool use_ue8m0) {
+  static constexpr int NUM_WARPS = 4;
+
+  using Idx_t = uint32_t;
+
+  Idx_t E = input.size(0);
+  Idx_t T = input.size(1);
+  Idx_t H = input.size(2) / 2;
+  Idx_t G = cuda_utils::ceil_div(H, Idx_t(group_size * NUM_WARPS));
+  Idx_t stride_i_e = input.stride(0);
+  Idx_t stride_i_t = input.stride(1);
+  Idx_t stride_i_h = input.stride(2);
+  Idx_t stride_yq_e = y_q.stride(0);
+  Idx_t stride_yq_t = y_q.stride(1);
+  Idx_t stride_yq_h = y_q.stride(2);
+  Idx_t stride_ys_e = y_s.stride(0);
+  Idx_t stride_ys_t = y_s.stride(1);
+  Idx_t stride_ys_g = y_s.stride(2);
+
+  int stride_counts_e = counts.stride(0);
+
+  static constexpr int NUM_PARALLEL_TOKENS = 16;
+  dim3 grid(E * G, NUM_PARALLEL_TOKENS);
+  dim3 block(NUM_WARPS * 32);
+
+  if (use_ue8m0) {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, true>
+        <<<grid, block>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<const uint32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  } else {
+    vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
+                                              NUM_PARALLEL_TOKENS, false>
+        <<<grid, block>>>(
+            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
+            reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr()),
+            y_s.data_ptr<float>(),
+            reinterpret_cast<const uint32_t*>(counts.data_ptr()), H, G,
+            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
+            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
+            static_cast<float>(fp8_min), static_cast<float>(fp8_max));
+  }
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 56626a02c027..63f9f77e8498 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   #define stride_tag
 #endif
 
+  ops.def(
+      "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
+      "y_q, Tensor!
y_s, int group_size, float eps, float fp8_min, float " + "fp8_max, bool use_ue8m0) -> ()"); + ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, + &silu_mul_fp8_quant_deep_gemm_cuda); + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5a0379dfb447..63f01f8483e6 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,16 +5,20 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + silu_mul_fp8_quant_deep_gemm_cuda) +from vllm.model_executor.layers.fused_moe.old_batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm as gold) from vllm.platforms import current_platform # (E, T, H, group_size, seed) CASES = [ - (1, 1, 128, 64, 0), - (1, 4, 128, 128, 0), - (2, 4, 256, 128, 0), - (32, 64, 256, 128, 0), - (17, 31, 768, 128, 0), + (8, 16, 7168, 128, 0), + (8, 32, 7168, 128, 0), + (8, 64, 7168, 128, 0), + (8, 128, 7168, 128, 0), + (8, 256, 7168, 128, 0), + (8, 512, 7168, 128, 0), + (8, 1024, 7168, 128, 0), ] @@ -34,50 +38,17 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) - - # Reference implementation - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max = fp8_info.max - fp8_min = fp8_info.min - eps = 1e-10 - - # Compute silu activation and elementwise multiplication - y1 = y[..., :H] - y2 = y[..., H:] - silu_x = y1 * torch.sigmoid(y1) - merged = silu_x * y2 - - # Compute reference scales and quantized output, skipping padded tokens - for e in range(E): - nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), - dtype=torch.float32, - device="cuda") - ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") - for t in range(nt): - data = merged[e, t] - data_grp = data.view(H // group_size, group_size) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max - - scaled = data / scale.repeat_interleave(group_size) - clamped = scaled.clamp(fp8_min, fp8_max) - q = clamped.to(torch.float8_e4m3fn) - - ref_s[t] = scale - ref_q[t] = q - - y_se = y_s[e] - y_qe = y_q[e] - - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) - torch.testing.assert_close( - y_qe[:nt].to(torch.float32), - ref_q[:nt].to(torch.float32), - atol=2, - rtol=2e-1, - ) + y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, + tokens_per_expert, + group_size=group_size, + eps=1e-10) + + gold_y_q, gold_y_s = gold(y, + tokens_per_expert, + group_size=group_size, + eps=1e-10) + torch.testing.assert_close(y_q.float(), + gold_y_q.float(), + atol=2, + rtol=2e-1) + torch.testing.assert_close(y_s.float(), gold_y_s.float()) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a5326dfe84f6..aea7acbfda54 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -19,48 +19,48 @@ @triton.jit def _silu_mul_fp8_quant_deep_gemm( - # Pointers ------------------------------------------------------------ - input_ptr, # 16-bit activations (E, T, 2*H) - y_q_ptr, # fp8 quantized activations (E, T, H) - y_s_ptr, # 
16-bit scales (E, T, G) - counts_ptr, # int32 num tokens per expert (E) + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) - # Sizes --------------------------------------------------------------- + # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) - GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) - # Strides for input (elements) --------------------------------------- + # Strides for input (elements) --------------------------------------- stride_i_e, - stride_i_t, - stride_i_h, - - # Strides for y_q (elements) ----------------------------------------- - stride_yq_e, - stride_yq_t, - stride_yq_h, - - # Strides for y_s (elements) ----------------------------------------- - stride_ys_e, - stride_ys_t, - stride_ys_g, - - # Stride for counts (elements) - stride_counts_e, - - # Numeric params ------------------------------------------------------ - eps: tl.constexpr, - fp8_min: tl.constexpr, - fp8_max: tl.constexpr, - use_ue8m0: tl.constexpr, - - # Meta --------------------------------------------------------------- - BLOCK: tl.constexpr, - NUM_STAGES: tl.constexpr, -): + stride_i_t, + stride_i_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, + NUM_WARPS: tl.constexpr): G = H // GROUP_SIZE # map program id -> (e, g) - pid = tl.program_id(0) + pid = tl.program_id(axis=0) e = pid // G g = pid % G @@ -71,7 +71,8 @@ def _silu_mul_fp8_quant_deep_gemm( n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) cols = tl.arange(0, BLOCK).to(tl.int64) - mask = cols < BLOCK + load_mask = cols < BLOCK + write_mask = cols < BLOCK base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h base_gate_offset = base_input_offset + cols * stride_i_h @@ -82,10 +83,10 @@ def _silu_mul_fp8_quant_deep_gemm( for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, - mask=mask, + mask=load_mask, other=0.0).to(tl.float32) up = tl.load(input_ptr + base_up_offset + t * stride_i_t, - mask=mask, + mask=load_mask, other=0.0) gate = gate * (1.0 / (1.0 + tl.exp(-gate))) @@ -97,10 +98,59 @@ def _silu_mul_fp8_quant_deep_gemm( y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) - tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, + y_q, + mask=write_mask) tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) +def silu_mul_fp8_quant_deep_gemm_cuda( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + eps: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + assert y.ndim == 3, "y must be (E, T, 
2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E + + tokens_per_expert = tokens_per_expert.to(device=y.device, + dtype=torch.int32) + + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device).contiguous() + + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided((E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device).contiguous() + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + use_ue8m0 = is_deep_gemm_e8m0_used() + torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda( + y, + tokens_per_expert, + y_q, + y_s, + group_size, + eps, + fp8_min, + fp8_max, + use_ue8m0, + ) + + return y_q, y_s + + def silu_mul_fp8_quant_deep_gemm( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert @@ -109,7 +159,7 @@ def silu_mul_fp8_quant_deep_gemm( ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales - y has shape (E, T, 2*H). The first half of the last dimension is + y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. Returns `(y_q, y_s)` where @@ -146,6 +196,8 @@ def silu_mul_fp8_quant_deep_gemm( stride_cnt_e = tokens_per_expert.stride()[0] + num_warps = 4 + # Static grid over experts and H-groups. # A loop inside the kernel handles the token dim grid = (E * G, ) @@ -177,7 +229,8 @@ def silu_mul_fp8_quant_deep_gemm( is_deep_gemm_e8m0_used(), BLOCK=group_size, NUM_STAGES=4, - num_warps=1, + NUM_WARPS=num_warps, + num_warps=num_warps, ) return y_q, y_s @@ -297,8 +350,7 @@ def apply( fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), workspace1, expert_num_tokens, expected_m) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, - expert_num_tokens) + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(workspace1, expert_num_tokens) fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, expert_num_tokens, expected_m) diff --git a/vllm/model_executor/layers/fused_moe/old_batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/old_batched_deep_gemm_moe.py new file mode 100644 index 000000000000..68bc8859b53e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/old_batched_deep_gemm_moe.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, + is_deep_gemm_e8m0_used) + +logger = init_logger(__name__) + + +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + 
counts_ptr, # int32 num tokens per expert (E) + + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + + cols * stride_yq_h) + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, + mask=mask, + other=0.0).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, + mask=mask, + other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + eps: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. 
+ + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, + dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided((E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G, ) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + # The Deep Gemm kernels only support block size of 128 + DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + + def __init__(self, + max_num_tokens: int, + num_dispatchers: int, + block_shape: list[int], + per_act_token_quant=False): + """ + max_num_tokens: Maximum number of tokens from a DP Rank + num_dispatchers: The number of DP dispatchers. + block_shape: Block quantization block shape. + per_act_token_quant: Per activation token quantization flag. + """ + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert a.dim() == 2 + # FIXME (varun): We should be able to dispatch only from the leader + # DP ranks in the case of TP > 1. At the moment, all the Ranks + # end up sending their tokens. 
This needs to be fixed. + num_dispatchers = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = a.size( + 0) if self.max_num_tokens is None else self.max_num_tokens + workspace13 = (num_experts, max_num_tokens * num_dispatchers, + max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) + output = (num_experts, max_num_tokens * num_dispatchers, K) + return (workspace13, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + assert hidden_states.ndim == 3 + assert self.block_shape is not None + + a1q = hidden_states + _, N, K = w1.size() + + assert w2.size(1) == K + + E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) + + # (from deepgemm docs) : A value hint (which is a value on CPU) + # for the M expectation of each batch, correctly setting this value + # may lead to better performance. + expected_m = max_num_tokens + fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + workspace1, expert_num_tokens, expected_m) + + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, + expert_num_tokens) + + fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, + expert_num_tokens, expected_m)
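
Reviewer note, not part of the patch: a minimal sketch of how the new CUDA path is meant to be exercised against the Triton implementation retained in old_batched_deep_gemm_moe.py. It assumes a CUDA device and a vLLM build that registers the new torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda op; unlike the test above, which compares the full (E, T, ...) tensors, this only checks the rows that are valid per expert.

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm_cuda)
from vllm.model_executor.layers.fused_moe.old_batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm as gold)

# Shapes follow the benchmark/test configs: E experts, up to T tokens each,
# hidden size H, with gate and up projections packed into the last dim (2*H).
E, T, H = 8, 64, 7168
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(1, T, (E,), dtype=torch.int32, device="cuda")

# New CUDA kernel vs. the Triton reference ("gold") kept for comparison.
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert, group_size=128)
ref_q, ref_s = gold(y, tokens_per_expert, group_size=128)

# Only the first tokens_per_expert[e] rows of each expert hold defined data.
for e in range(E):
    nt = int(tokens_per_expert[e])
    torch.testing.assert_close(y_q[e, :nt].float(), ref_q[e, :nt].float(),
                               atol=2, rtol=2e-1)
    torch.testing.assert_close(y_s[e, :nt], ref_s[e, :nt])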