diff --git a/.github/workflows/third-party-benchmarks.yml b/.github/workflows/third-party-benchmarks.yml index 59a1a9158f..c1687ddba0 100644 --- a/.github/workflows/third-party-benchmarks.yml +++ b/.github/workflows/third-party-benchmarks.yml @@ -82,6 +82,25 @@ jobs: cd benchmarks pip install . + - name: Run sglang benchmark int8 + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} + run: | + source ./scripts/capture-hw-details.sh + + ./scripts/test-triton.sh --install-sglang --skip-pip-install --skip-pytorch-install + cd benchmarks/third_party/sglang + python scaled_mm_benchmark.py --reports $REPORTS + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 --param_cols="M,N,K" --bgroup sglang + + - name: Run sglang benchmark with fp8 + if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'sglang')) }} + run: | + source ./scripts/capture-hw-details.sh + + cd benchmarks/third_party/sglang + FP8="1" python scaled_mm_benchmark.py --reports $REPORTS + python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 --param_cols="M,N,K" --bgroup sglang + - name: Run vllm benchmarks bf16 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} run: | @@ -92,7 +111,8 @@ jobs: cd benchmarks/third_party/vllm python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-report.csv --tag $TAG --benchmark moe-bf16-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm + - name: Run vllm benchmarks fp8 if: ${{ steps.install-benchmarks.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'vllm')) }} @@ -101,7 +121,8 @@ jobs: cd benchmarks/third_party/vllm FP8="1" python batched_moe_benchmark.py --reports $REPORTS - python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark + python transform_results.py $REPORTS/moe-gemm-performance.csv $REPORTS/moe-gemm-fp8-report.csv --tag $TAG --benchmark moe-fp8-benchmark --param_cols="num_experts,max_tokens_per_expert,K,N" --bgroup vllm + - name: Run Liger-Kernel benchmarks if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'liger')) }} diff --git a/benchmarks/third_party/sglang/scaled_mm_benchmark.py b/benchmarks/third_party/sglang/scaled_mm_benchmark.py new file mode 100644 index 0000000000..d79850a3b9 --- /dev/null +++ b/benchmarks/third_party/sglang/scaled_mm_benchmark.py @@ -0,0 +1,391 @@ +# From +# https://github.com/sgl-project/sglang/blob/6d0364681c8b1abc132cc88f1bb0b7a8a352628f/test/srt/quant/test_triton_scaled_mm.py +# https://github.com/sgl-project/sglang/blob/6d0364681c8b1abc132cc88f1bb0b7a8a352628f/python/sglang/srt/layers/quantization/fp8_kernel.py +import os +from typing import Optional, List + +import torch +import triton +import triton.language as 
tl + +import triton_kernels_benchmark as benchmark_suite + +from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm + + +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +def get_matmul_batched_autotune_configs() -> List[triton.Config]: + configs = [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2, 3] + ] + [ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w) + for s in [2] + for (m, w) in ([('large', 32), ('small', 64)]) + ] + [ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32) + for s in [2] + ] + [ + triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4) + for s in [2] + ] + return configs + + +@triton.jit +def scaled_mm_kernel_td( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am: tl.int64, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + # offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + # masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + # masks_bn = offsets_bn < N + + # offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + # offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + # offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. 
+ offsets_scale_am = tl.arange(0, BLOCK_SIZE_SCALE_A) + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + masks_scale_bn = offsets_scale_bn < N + + a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)) + b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # a_ptrs = a_ptr + offsets_a + # b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + off_k = 0 + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # masks_k = offsets_k < K + # masks_a = masks_am[:, None] & masks_k[None, :] + # a = tl.load(a_ptrs, mask=masks_a) + + # masks_b = masks_k[:, None] & masks_bn[None, :] + # b = tl.load(b_ptrs, mask=masks_b) + + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) + # accumulator += tl.dot(a, b) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + off_k += BLOCK_SIZE_K + + # offsets_k += BLOCK_SIZE_K + # a_ptrs += BLOCK_SIZE_K * stride_ak + # b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. + masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. 
+ if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + # tl.store(c_ptrs, c, mask=c_mask) + c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + + +# input - [M, K] +# weight - [K, N] +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +def triton_scaled_mm_td( + input: torch.Tensor, # pylint: disable=redefined-builtin + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True, +) -> torch.Tensor: + M, K = input.shape + N = weight.shape[1] + + assert N > 0 and K > 0 and M > 0 + assert weight.shape[0] == K + assert input.dtype == weight.dtype + + scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a + scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b + + assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N) + assert out_dtype.is_floating_point + assert bias is None or bias.is_floating_point() + assert is_weak_contiguous(input) + assert is_weak_contiguous(weight) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) + else: + tile_shape = (128, 128, 128) + else: + raise NotImplementedError('Only heuristic-based tile size selection is supported currently.') + + block_size_m, block_size_n, block_size_k = tile_shape + + block_size_sa = 1 if has_scalar(scale_a) else block_size_m + block_size_sb = 1 if has_scalar(scale_b) else block_size_n + + accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 + + # A = input, B = weight, C = result + # A = M x K, B = K x N, C = M x N + scaled_mm_kernel_td[grid]( + input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb, + ) + return result + + +torch.set_default_device('xpu') +device = 'xpu' + + +def torch_scaled_mm( + a: 
torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: torch.dtype,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """Reference implementation using float32 for stability"""
+    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+    out = scale_a.to(torch.float32) * out * scale_b.to(torch.float32).T
+    if bias is not None:
+        out = out + bias.to(torch.float32)
+    return out.to(out_dtype)
+
+
+def _make_inputs(M, K, N, in_dtype):
+    if in_dtype == torch.int8:
+        a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device=device)
+        b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device=device)
+    else:  # fp8
+        # Adding zero helps avoid NaNs for some reason; without it there are occasional accidental NaNs
+        a = (0 + torch.clamp(0.5 * torch.randn((M, K), dtype=torch.float16, device=device), -0.25, 0.25)).to(in_dtype)
+        b = 0.5 * torch.randn((K, N), dtype=torch.float16, device=device)
+        b = torch.clamp(b, -0.25, 0.25)
+        # Adding zero helps avoid NaNs for some reason; without it there are occasional accidental NaNs
+        b = (0 + b).to(in_dtype)
+    return a, b
+
+
+X_VALS = sum([[  #
+    # [M, 128, 128],
+    [M, 1024, 4096], [M, 4096, 4096], [M, 4096, 4096 * 4]
+] for M in [1, 8, 128, 1024, 4096]], [])
+
+
+def get_scaled_mm_benchmark(
+    providers_filter: Optional[list[str]] = None,
+    fp8=False,
+    plot_name: str = 'scaled_mm_benchmark',
+):
+    supported_providers = {
+        'triton': 'triton',
+        'triton-td': 'triton-td',
+        'pytorch': 'pytorch-deqmm',
+    }
+    providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
+
+    @benchmark_suite.perf_report(
+        benchmark_suite.Benchmark(
+            x_names=['M', 'N', 'K'],
+            x_vals=X_VALS,
+            line_arg='provider',
+            line_vals=list(providers.keys()),
+            line_names=list(providers.values()),
+            styles=[('green', '-'), ('blue', '--'), ('red', ':')],
+            ylabel=['GB/s', 'TFlops'],
+            plot_name=plot_name,
+            args={},
+        ))
+    def benchmark(M, N, K, provider, with_bias=False):
+        torch.manual_seed(10)
+        n_warmup = 600
+
+        quantiles = [0.5, 0, 1.0]
+
+        if fp8:
+            in_dtype, out_dtype = torch.float8_e4m3fn, torch.float32
+        else:
+            in_dtype, out_dtype = torch.int8, torch.bfloat16
+
+        x, weight = _make_inputs(M, K, N, in_dtype)
+        scale_a = 0.1 + 0.05 * torch.rand((M, 1), dtype=torch.float32, device=device)
+        scale_b = 0.1 + 0.05 * torch.rand((N, 1), dtype=torch.float32, device=device)
+        bias = (0.01 * torch.randn((M, N), dtype=out_dtype, device=device) if with_bias else None)
+
+        def torch_fn():
+            return torch_scaled_mm(x, weight, scale_a, scale_b, out_dtype, bias)
+
+        # Use relaxed tolerances
+        rtol = 0.15 if in_dtype == torch.int8 else 0.25
+        atol = 0.1 if in_dtype == torch.int8 else 0.15
+
+        if provider == 'pytorch':
+            # PyTorch reference implementation: plain fp32 matmul with explicit scaling
+            _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
+                torch_fn,
+                n_warmup=n_warmup,
+                n_repeat=10,
+                quantiles=quantiles,
+            )
+
+        elif provider in ('triton', 'triton-td'):
+            invoke_kernel = triton_scaled_mm if provider == 'triton' else triton_scaled_mm_td
+
+            def triton_fn():
+                return invoke_kernel(x, weight, scale_a, scale_b, out_dtype, bias)
+
+            benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg='triton to torch')
+
+            _, min_ms, max_ms, mean_ms, cv = benchmark_suite.do_bench(
+                triton_fn,
+                n_warmup=n_warmup,
+                n_repeat=10,
+                quantiles=quantiles,
+            )
+
+        else:
+            raise NotImplementedError(f'Unsupported provider {provider}')
+
+        def gbps(ms):
+            total_bytes = in_dtype.itemsize * (M * K + K * N) + out_dtype.itemsize * M * N
+            
return total_bytes * (1e-9) / (ms * 1e-3) + + def tflops(ms): + total_flops = M * N * K * 2 + return total_flops * (1e-12) / (ms * 1e-3) + + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv + + return benchmark + + +if __name__ == '__main__': + _benchmark_mm = get_scaled_mm_benchmark(fp8=(os.getenv('FP8', '0') == '1'), ) + _benchmark_mm.run(show_plots=False, print_data=True) diff --git a/benchmarks/third_party/sglang/sglang-bench-fix.patch b/benchmarks/third_party/sglang/sglang-bench-fix.patch new file mode 100644 index 0000000000..7c23b93a09 --- /dev/null +++ b/benchmarks/third_party/sglang/sglang-bench-fix.patch @@ -0,0 +1,208 @@ +diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py +index 3aaf301bb..68b6520d4 100644 +--- a/python/sglang/srt/layers/linear.py ++++ b/python/sglang/srt/layers/linear.py +@@ -18,9 +18,9 @@ from sglang.srt.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + ) +-from sglang.srt.distributed.device_communicators.pynccl_allocator import ( +- use_symmetric_memory, +-) ++# from sglang.srt.distributed.device_communicators.pynccl_allocator import ( ++ # use_symmetric_memory, ++# ) + from sglang.srt.layers.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, +diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py +index df0658f86..e69de29bb 100644 +--- a/python/sglang/srt/layers/quantization/__init__.py ++++ b/python/sglang/srt/layers/quantization/__init__.py +@@ -1,173 +0,0 @@ +-# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py +-from __future__ import annotations +- +-import builtins +-import inspect +-from typing import TYPE_CHECKING, Dict, Optional, Type +- +-import torch +- +-try: +- from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +- from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig +- from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig +- from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config +- from vllm.model_executor.layers.quantization.gguf import GGUFConfig +- from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( +- GPTQMarlin24Config, +- ) +- from vllm.model_executor.layers.quantization.marlin import MarlinConfig +- from vllm.model_executor.layers.quantization.qqq import QQQConfig +- from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig +- +- VLLM_AVAILABLE = True +-except ImportError as e: +- VLLM_AVAILABLE = False +- VLLM_IMPORT_ERROR = e +- +- # Define empty classes as placeholders when vllm is not available +- class DummyConfig: +- def override_quantization_method(self, *args, **kwargs): +- return None +- +- AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = ( +- ExpertsInt8Config +- ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = ( +- DummyConfig +- ) +- +- +-from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig +-from sglang.srt.layers.quantization.base_config import QuantizationConfig +-from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config +-from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( +- CompressedTensorsConfig, +-) +-from sglang.srt.layers.quantization.fp8 import Fp8Config +-from 
sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config +-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +-from sglang.srt.layers.quantization.modelopt_quant import ( +- ModelOptFp4Config, +- ModelOptFp8Config, +-) +-from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config +-from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config +-from sglang.srt.layers.quantization.petit import PetitNvFp4Config +-from sglang.srt.layers.quantization.qoq import QoQConfig +-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config +-from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config +-from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +-from sglang.srt.utils import is_cuda, is_hip, mxfp_supported +- +-_is_mxfp_supported = mxfp_supported() +- +-if TYPE_CHECKING: +- from sglang.srt.layers.moe.topk import TopKOutput +- +-# Base quantization methods that don't depend on vllm +-BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { +- "fp8": Fp8Config, +- "blockwise_int8": BlockInt8Config, +- "modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8 +- "modelopt_fp8": ModelOptFp8Config, +- "modelopt_fp4": ModelOptFp4Config, +- "w8a8_int8": W8A8Int8Config, +- "w8a8_fp8": W8A8Fp8Config, +- "awq": AWQConfig, +- "awq_marlin": AWQMarlinConfig, +- "gptq": GPTQConfig, +- "gptq_marlin": GPTQMarlinConfig, +- "moe_wna16": MoeWNA16Config, +- "compressed-tensors": CompressedTensorsConfig, +- "qoq": QoQConfig, +- "w4afp8": W4AFp8Config, +- "petit_nvfp4": PetitNvFp4Config, +- "fbgemm_fp8": FBGEMMFp8Config, +-} +- +- +-if is_cuda(): +- BASE_QUANTIZATION_METHODS.update( +- { +- "quark": Mxfp4Config, +- "mxfp4": Mxfp4Config, +- } +- ) +-elif _is_mxfp_supported and is_hip(): +- from sglang.srt.layers.quantization.quark.quark import QuarkConfig +- +- BASE_QUANTIZATION_METHODS.update( +- { +- "quark": QuarkConfig, +- "mxfp4": Mxfp4Config, +- } +- ) +-# VLLM-dependent quantization methods +-VLLM_QUANTIZATION_METHODS = { +- "aqlm": AQLMConfig, +- "deepspeedfp": DeepSpeedFPConfig, +- "tpu_int8": Int8TpuConfig, +- "marlin": MarlinConfig, +- "gguf": GGUFConfig, +- "gptq_marlin_24": GPTQMarlin24Config, +- "bitsandbytes": BitsAndBytesConfig, +- "qqq": QQQConfig, +- "experts_int8": ExpertsInt8Config, +-} +- +-QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS} +- +- +-def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: +- if quantization not in QUANTIZATION_METHODS: +- raise ValueError( +- f"Invalid quantization method: {quantization}. " +- f"Available methods: {list(QUANTIZATION_METHODS.keys())}" +- ) +- if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE: +- raise ValueError( +- f"{quantization} quantization requires some operators from vllm. 
" +- f"Please install vllm by `pip install vllm==0.9.0.1`\n" +- f"Import error: {VLLM_IMPORT_ERROR}" +- ) +- +- return QUANTIZATION_METHODS[quantization] +- +- +-original_isinstance = builtins.isinstance +- +- +-def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False): +- """ +- Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig +- can recognize sglang layers +- """ +- if not VLLM_AVAILABLE: +- return +- +- if reverse: +- builtins.isinstance = original_isinstance +- return +- +- from vllm.model_executor.layers.fused_moe import FusedMoE +- from vllm.model_executor.layers.linear import LinearBase +- from vllm.model_executor.layers.vocab_parallel_embedding import ( +- VocabParallelEmbedding, +- ) +- +- from sglang.srt.layers.linear import LinearBase as PatchedLinearBase +- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE +- from sglang.srt.layers.vocab_parallel_embedding import ( +- VocabParallelEmbedding as PatchedVocabParallelEmbedding, +- ) +- +- def patched_isinstance(obj, classinfo): +- if classinfo is LinearBase: +- return original_isinstance(obj, PatchedLinearBase) +- if classinfo is FusedMoE: +- return original_isinstance(obj, PatchedFusedMoE) +- if classinfo is VocabParallelEmbedding: +- return original_isinstance(obj, PatchedVocabParallelEmbedding) +- return original_isinstance(obj, classinfo) +- +- builtins.isinstance = patched_isinstance +diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py +index e98b7f6ff..9b4761a30 100644 +--- a/python/sglang/srt/layers/quantization/fp8_kernel.py ++++ b/python/sglang/srt/layers/quantization/fp8_kernel.py +@@ -23,7 +23,7 @@ import torch + import triton + import triton.language as tl + +-from sglang.srt.layers import deep_gemm_wrapper ++# from sglang.srt.layers import deep_gemm_wrapper + from sglang.srt.utils import ( + align, + direct_register_custom_op, diff --git a/benchmarks/third_party/sglang/sglang-fix.patch b/benchmarks/third_party/sglang/sglang-test-fix.patch similarity index 100% rename from benchmarks/third_party/sglang/sglang-fix.patch rename to benchmarks/third_party/sglang/sglang-test-fix.patch diff --git a/benchmarks/third_party/vllm/transform_results.py b/benchmarks/third_party/vllm/transform_results.py index 66ada1157c..e5baa9ab42 100644 --- a/benchmarks/third_party/vllm/transform_results.py +++ b/benchmarks/third_party/vllm/transform_results.py @@ -11,14 +11,20 @@ def parse_args(): parser = argparse.ArgumentParser(description='Parse MoE benchmark CSV') parser.add_argument('source', help='Path to the MoE benchmark CSV file') parser.add_argument('target', help='Path to output CSV file') + parser.add_argument( + '--param_cols', + help='Names of parameter columns, separated by commas.', + required=True, + ) parser.add_argument('--tag', help='Tag for the benchmark run', default='') - parser.add_argument('--benchmark', help='moe-benchmark', default='') + parser.add_argument('--benchmark', help='moe-benchmark', required=True) + parser.add_argument('--bgroup', help='Benchmark group', required=True) return parser.parse_args() -def parse_moe_csv(csv_file_path, tag, benchmark): - """Parse the MoE benchmark CSV and extract performance metrics.""" +def parse_csv(csv_file_path, tag, bench_group, benchmark, param_cols): + """Parse the benchmark CSV and extract performance metrics.""" df = pd.read_csv(csv_file_path) @@ -26,13 +32,7 @@ def parse_moe_csv(csv_file_path, tag, benchmark): current_datetime = 
datetime.now().isoformat() # Create params for all rows vectorized - df['params'] = df.apply( - lambda row: json.dumps({ - 'num_experts': int(row['num_experts']), - 'max_tokens_per_expert': int(row['max_tokens_per_expert']), - 'K': int(row['K']), - 'N': int(row['N']), - }), axis=1) + df['params'] = df.apply(lambda row: json.dumps({p: int(row[p]) for p in param_cols}), axis=1) # Define compiler columns compilers = [('triton', 'triton-TFlops'), ('pytorch', 'pytorch-TFlops'), ('triton-td', 'triton-td-TFlops')] @@ -46,7 +46,7 @@ def parse_moe_csv(csv_file_path, tag, benchmark): if len(valid_rows) > 0: valid_rows['run_uuid'] = run_uuid valid_rows['ts'] = current_datetime - valid_rows['benchmark_group'] = 'moe-benchmark' + valid_rows['benchmark_group'] = bench_group valid_rows['benchmark'] = benchmark valid_rows['compiler'] = compiler_name valid_rows['value_name'] = 'tflops' @@ -90,7 +90,8 @@ def main(): if not os.path.exists(args.source): raise ValueError(f'Error: CSV file {args.source} not found') - df_results = parse_moe_csv(args.source, args.tag, args.benchmark) + param_cols = args.param_cols.split(',') + df_results = parse_csv(args.source, args.tag, args.bgroup, args.benchmark, param_cols) df_results.to_csv(args.target, index=False) diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 69e72e945d..d9ba1f088c 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -619,7 +619,8 @@ run_sglang_install() { if ! pip list | grep "sglang" ; then cd sglang git checkout "$(<../benchmarks/third_party/sglang/sglang-pin.txt)" - git apply ../benchmarks/third_party/sglang/sglang-fix.patch + git apply ../benchmarks/third_party/sglang/sglang-test-fix.patch + git apply ../benchmarks/third_party/sglang/sglang-bench-fix.patch # That's how sglang assumes we'll pick out platform for now cp python/pyproject_xpu.toml python/pyproject.toml
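
# --- Local reproduction note (not part of the patch) ---
# A minimal sketch of how to exercise the new sglang scaled_mm benchmark locally,
# mirroring the workflow steps added above. It assumes an XPU-enabled PyTorch/Triton
# environment, that `pip install .` has already been run in benchmarks/, and that
# REPORTS points at a writable directory; TAG is an arbitrary label for the run.
#
#   source ./scripts/capture-hw-details.sh
#   ./scripts/test-triton.sh --install-sglang --skip-pip-install --skip-pytorch-install
#   cd benchmarks/third_party/sglang
#   python scaled_mm_benchmark.py --reports $REPORTS        # int8 path
#   python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv \
#       $REPORTS/scaled-mm-int8-report.csv --tag $TAG --benchmark scaled-mm-int8 \
#       --param_cols="M,N,K" --bgroup sglang
#   FP8=1 python scaled_mm_benchmark.py --reports $REPORTS  # fp8 path
#   python ../vllm/transform_results.py $REPORTS/scaled_mm_benchmark.csv \
#       $REPORTS/scaled-mm-fp8-report.csv --tag $TAG --benchmark scaled-mm-fp8 \
#       --param_cols="M,N,K" --bgroup sglang
#
# Both variants write $REPORTS/scaled_mm_benchmark.csv, so transform each result
# before running the other variant, as the workflow does.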