From 986918fa2c4d7905f7fbe368743349dfbf91dcc5 Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Fri, 26 Sep 2025 16:55:32 +0000 Subject: [PATCH 1/3] add aiter bpreshuffle block scaled gemm support Signed-off-by: tjtanaavllm --- vllm/_aiter_ops.py | 47 ++++++++++++++- .../model_executor/layers/quantization/fp8.py | 20 +++++-- .../layers/quantization/utils/fp8_utils.py | 57 +++++-------------- 3 files changed, 75 insertions(+), 49 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 3f564506df0d..2be65bce5058 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import IntEnum from typing import Optional import torch @@ -8,6 +9,16 @@ from vllm.utils import direct_register_custom_op +class AITEROpMode(IntEnum): + NO_AITER = 0 + AITER = 1 + SWIZZLE_AITER = 2 + + +def use_swizzle(n: int, k: int, layout: tuple[int, int]) -> bool: + return n % layout[0] == 0 and k % (layout[1] * 2) == 0 + + def use_swizzle_gemm(n: int, k: int, dtype: torch.dtype) -> bool: multiple_of: int = 64 @@ -84,4 +95,38 @@ def rocm_aiter_tuned_gemm( out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, - ) \ No newline at end of file + ) + + @staticmethod + def gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + import aiter as rocm_aiter + + return rocm_aiter.gemm_a8w8_blockscale(A, + B, + As, + Bs, + dtype=output_dtype) + + @staticmethod + def gemm_a8w8_blockscale_bpreshuffle( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + import aiter as rocm_aiter + As_t = As.transpose(0, 1).contiguous().view(*As.shape) + return rocm_aiter.gemm_a8w8_blockscale_bpreshuffle(A, + B, + As_t, + Bs, + dtype=output_dtype) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4b80b2b5c46a..d850b9e0dc91 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,6 +11,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import AITEROpMode from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -232,10 +233,10 @@ def __init__(self, quant_config: Fp8Config): # AITER is only supported on ROCm and only for FP8_FNUZ # and at the moment are MI300 series - self.use_aiter_and_is_supported = (current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()) + self.use_aiter_and_is_supported: int = int( + current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) self.block_quant = self.quant_config.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" @@ -387,6 +388,17 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = self._maybe_pad_weight(weight) + if self.use_aiter_and_is_supported: + from vllm._aiter_ops import use_swizzle + layout = (16, 16) + if use_swizzle(*weight.shape, layout): + self.use_aiter_and_is_supported 
= \ + AITEROpMode.SWIZZLE_AITER.value + if (self.use_aiter_and_is_supported == \ + AITEROpMode.SWIZZLE_AITER.value): + from aiter.ops.shuffle import shuffle_weight + weight = shuffle_weight(weight, layout=(16, 16)) + # Torch.compile cannot use Parameter subclasses. layer.weight = Parameter(weight, requires_grad=False) layer.weight_scale_inv = Parameter(weight_scale_inv, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2b86c6a71dcc..a09e4b54a607 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) @@ -53,49 +54,14 @@ def cutlass_scaled_mm( scale_b=Bs.T) -def rocm_aiter_gemm_w8a8_blockscale_impl( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - import aiter as rocm_aiter - - return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) - - -def rocm_aiter_gemm_w8a8_blockscale_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - - m = A.shape[0] - n = B.shape[0] - Y = torch.empty(m, n, dtype=output_dtype, device=A.device) - return Y - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8_blockscale", - op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - mutates_args=[], - fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, - ) - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()): +if (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()): - import aiter as rocm_aiter - from aiter import get_hip_quant + import aiter as rocm_aiter + from aiter import get_hip_quant - aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) + aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) def dispatch_w8a8_blockscale_func( @@ -110,8 +76,11 @@ def dispatch_w8a8_blockscale_func( ], torch.Tensor]: if use_cutlass: return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale + if (use_aiter_and_is_supported == 1): + # return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale + return aiter_ops.gemm_a8w8_blockscale + elif (use_aiter_and_is_supported == 2): + return aiter_ops.gemm_a8w8_blockscale_bpreshuffle return w8a8_block_fp8_matmul @@ -125,7 +94,7 @@ def apply_w8a8_block_fp8_linear( input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, + use_aiter_and_is_supported: int = 0, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods From f71e80b9598402d9493feb330373d98638d2ea72 Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Sat, 27 Sep 2025 03:05:26 +0000 Subject: [PATCH 2/3] use quant_hip to preshufflescale Signed-off-by: tjtanaavllm --- vllm/_aiter_ops.py | 3 +-- 
vllm/model_executor/layers/quantization/utils/fp8_utils.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 2be65bce5058..4cdeabbe108d 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -124,9 +124,8 @@ def gemm_a8w8_blockscale_bpreshuffle( output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: import aiter as rocm_aiter - As_t = As.transpose(0, 1).contiguous().view(*As.shape) return rocm_aiter.gemm_a8w8_blockscale_bpreshuffle(A, B, - As_t, + As, Bs, dtype=output_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index a09e4b54a607..f648e0b18517 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -152,7 +152,9 @@ def apply_w8a8_block_fp8_linear( else: if use_aiter_and_is_supported: q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + input_2d.contiguous(), + quant_dtype=rocm_aiter.dtypes.fp8, + transpose_scale=use_aiter_and_is_supported == 2) else: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) From 4bbea11358b3a66a9f93b64d142c262b6371e265 Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Sat, 27 Sep 2025 03:26:27 +0000 Subject: [PATCH 3/3] clean code Signed-off-by: tjtanaavllm --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f648e0b18517..6dcf846765bb 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -77,7 +77,6 @@ def dispatch_w8a8_blockscale_func( if use_cutlass: return cutlass_scaled_mm if (use_aiter_and_is_supported == 1): - # return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return aiter_ops.gemm_a8w8_blockscale elif (use_aiter_and_is_supported == 2): return aiter_ops.gemm_a8w8_blockscale_bpreshuffle
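
Putting the three patches together, the end-to-end flow is: pick an AITEROpMode from the weight shape, shuffle the weight once at load time when the (16, 16) swizzle layout fits, then at run time quantize the activation per 1x128 group (transposing the scale only for the pre-shuffled path) and dispatch to the matching AITER kernel. The sketch below is illustrative only and is not part of the patch series; the aiter entry points and keyword arguments (shuffle_weight, get_hip_quant, transpose_scale, gemm_a8w8_blockscale_bpreshuffle) are assumed to behave exactly as they are used in the hunks above, and the Triton/CUTLASS fallback for the NO_AITER mode is elided.

    # Illustrative sketch (not part of the patch series) of how mode selection,
    # load-time weight shuffling and run-time dispatch compose.
    from enum import IntEnum

    import torch


    class AITEROpMode(IntEnum):
        NO_AITER = 0        # fall back to the Triton/CUTLASS paths
        AITER = 1           # plain aiter block-scaled GEMM
        SWIZZLE_AITER = 2   # pre-shuffled weight + transposed activation scale


    def select_mode(weight: torch.Tensor, aiter_enabled: bool,
                    layout: tuple[int, int] = (16, 16)) -> AITEROpMode:
        # Mirrors use_swizzle(): N must be divisible by layout[0] and K by
        # 2 * layout[1] for the bpreshuffle kernel to be usable.
        if not aiter_enabled:
            return AITEROpMode.NO_AITER
        n, k = weight.shape
        if n % layout[0] == 0 and k % (layout[1] * 2) == 0:
            return AITEROpMode.SWIZZLE_AITER
        return AITEROpMode.AITER


    def prepare_weight(weight: torch.Tensor, mode: AITEROpMode) -> torch.Tensor:
        # One-time step in process_weights_after_loading(): only the swizzle
        # path re-lays the weight out for the bpreshuffle kernel.
        if mode == AITEROpMode.SWIZZLE_AITER:
            from aiter.ops.shuffle import shuffle_weight
            return shuffle_weight(weight, layout=(16, 16))
        return weight


    def run_blockscale_gemm(mode: AITEROpMode, x_2d: torch.Tensor,
                            weight: torch.Tensor,
                            weight_scale_inv: torch.Tensor,
                            out_dtype: torch.dtype = torch.bfloat16
                            ) -> torch.Tensor:
        # Per-call step in apply_w8a8_block_fp8_linear(): quantize the
        # activation per 1x128 group, transposing the scale only for the
        # bpreshuffle kernel, then dispatch to the matching GEMM.
        assert mode != AITEROpMode.NO_AITER, "non-AITER fallback elided here"
        import aiter as rocm_aiter
        from aiter import get_hip_quant

        quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
        q_input, x_scale = quant(
            x_2d.contiguous(),
            quant_dtype=rocm_aiter.dtypes.fp8,
            transpose_scale=(mode == AITEROpMode.SWIZZLE_AITER))

        if mode == AITEROpMode.SWIZZLE_AITER:
            return rocm_aiter.gemm_a8w8_blockscale_bpreshuffle(
                q_input, weight, x_scale, weight_scale_inv, dtype=out_dtype)
        return rocm_aiter.gemm_a8w8_blockscale(
            q_input, weight, x_scale, weight_scale_inv, dtype=out_dtype)

In this sketch the mode is computed once per layer and stored, which matches how the patches fold it into the integer-valued use_aiter_and_is_supported field rather than re-checking the weight shape on every forward pass.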