diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 3f564506df0d..4cdeabbe108d 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import IntEnum
 from typing import Optional
 
 import torch
@@ -8,6 +9,16 @@
 from vllm.utils import direct_register_custom_op
 
 
+class AITEROpMode(IntEnum):
+    NO_AITER = 0
+    AITER = 1
+    SWIZZLE_AITER = 2
+
+
+def use_swizzle(n: int, k: int, layout: tuple[int, int]) -> bool:
+    return n % layout[0] == 0 and k % (layout[1] * 2) == 0
+
+
 def use_swizzle_gemm(n: int, k: int, dtype: torch.dtype) -> bool:
 
     multiple_of: int = 64
@@ -84,4 +95,37 @@ def rocm_aiter_tuned_gemm(
             out_dtype=out_dtype,
             scale_a=scale_a,
             scale_b=scale_b,
-        )
\ No newline at end of file
+        )
+
+    @staticmethod
+    def gemm_a8w8_blockscale(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        import aiter as rocm_aiter
+
+        return rocm_aiter.gemm_a8w8_blockscale(A,
+                                               B,
+                                               As,
+                                               Bs,
+                                               dtype=output_dtype)
+
+    @staticmethod
+    def gemm_a8w8_blockscale_bpreshuffle(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+    ) -> torch.Tensor:
+        import aiter as rocm_aiter
+        return rocm_aiter.gemm_a8w8_blockscale_bpreshuffle(A,
+                                                           B,
+                                                           As,
+                                                           Bs,
+                                                           dtype=output_dtype)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4b80b2b5c46a..d850b9e0dc91 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -11,6 +11,7 @@
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import AITEROpMode
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
@@ -232,10 +233,10 @@ def __init__(self, quant_config: Fp8Config):
 
         # AITER is only supported on ROCm and only for FP8_FNUZ
         # and at the moment are MI300 series
-        self.use_aiter_and_is_supported = (current_platform.is_rocm()
-                                           and envs.VLLM_ROCM_USE_AITER
-                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
-                                           and current_platform.is_fp8_fnuz())
+        self.use_aiter_and_is_supported: int = int(
+            current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+            and envs.VLLM_ROCM_USE_AITER_LINEAR
+            and current_platform.is_fp8_fnuz())
 
         self.block_quant = self.quant_config.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
@@ -387,6 +388,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             weight = self._maybe_pad_weight(weight)
 
+            if self.use_aiter_and_is_supported:
+                from vllm._aiter_ops import use_swizzle
+                layout = (16, 16)
+                if use_swizzle(*weight.shape, layout):
+                    self.use_aiter_and_is_supported = \
+                        AITEROpMode.SWIZZLE_AITER.value
+                if (self.use_aiter_and_is_supported == \
+                        AITEROpMode.SWIZZLE_AITER.value):
+                    from aiter.ops.shuffle import shuffle_weight
+                    weight = shuffle_weight(weight, layout=(16, 16))
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 2b86c6a71dcc..6dcf846765bb 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -12,6 +12,7 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     group_broadcast)
@@ -53,49 +54,14 @@ def cutlass_scaled_mm(
                             scale_b=Bs.T)
 
 
-def rocm_aiter_gemm_w8a8_blockscale_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
-    import aiter as rocm_aiter
-
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-
-
-def rocm_aiter_gemm_w8a8_blockscale_fake(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
-
-    m = A.shape[0]
-    n = B.shape[0]
-    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
-    return Y
-
-
-if current_platform.is_rocm():
-    direct_register_custom_op(
-        op_name="rocm_aiter_gemm_w8a8_blockscale",
-        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-        mutates_args=[],
-        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-    if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
-            and current_platform.is_fp8_fnuz()):
+if (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+        and envs.VLLM_ROCM_USE_AITER_LINEAR
+        and current_platform.is_fp8_fnuz()):
 
-        import aiter as rocm_aiter
-        from aiter import get_hip_quant
+    import aiter as rocm_aiter
+    from aiter import get_hip_quant
 
-        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+    aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
 
 
 def dispatch_w8a8_blockscale_func(
@@ -110,8 +76,10 @@ def dispatch_w8a8_blockscale_func(
 ], torch.Tensor]:
     if use_cutlass:
         return cutlass_scaled_mm
-    if (use_aiter_and_is_supported):
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    if (use_aiter_and_is_supported == 1):
+        return aiter_ops.gemm_a8w8_blockscale
+    elif (use_aiter_and_is_supported == 2):
+        return aiter_ops.gemm_a8w8_blockscale_bpreshuffle
     return w8a8_block_fp8_matmul
 
 
@@ -125,7 +93,7 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-    use_aiter_and_is_supported: bool = False,
+    use_aiter_and_is_supported: int = 0,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -183,7 +151,9 @@ def apply_w8a8_block_fp8_linear(
     else:
         if use_aiter_and_is_supported:
             q_input, x_scale = aiter_per1x128_quant(
-                input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
+                input_2d.contiguous(),
+                quant_dtype=rocm_aiter.dtypes.fp8,
+                transpose_scale=use_aiter_and_is_supported == 2)
         else:
             q_input, x_scale = per_token_group_quant_fp8(
                 input_2d, block_size[1], column_major_scales=use_cutlass)
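
For reviewers, a minimal standalone sketch of how the new tri-state flag behaves. It reuses `AITEROpMode` and `use_swizzle` from the diff; `select_mode` and the sample shapes are illustrative only (not part of the change), and it runs without ROCm or aiter installed:

```python
from enum import IntEnum


class AITEROpMode(IntEnum):
    NO_AITER = 0        # fall back to the Triton/CUTLASS paths
    AITER = 1           # dispatches to aiter_ops.gemm_a8w8_blockscale
    SWIZZLE_AITER = 2   # dispatches to aiter_ops.gemm_a8w8_blockscale_bpreshuffle


def use_swizzle(n: int, k: int, layout: tuple[int, int]) -> bool:
    # Same shape check as the new helper in vllm/_aiter_ops.py:
    # N must be a multiple of layout[0] and K a multiple of 2 * layout[1].
    return n % layout[0] == 0 and k % (layout[1] * 2) == 0


def select_mode(aiter_enabled: bool,
                weight_shape: tuple[int, int],
                layout: tuple[int, int] = (16, 16)) -> AITEROpMode:
    # Hypothetical helper mirroring process_weights_after_loading: the flag
    # starts as int(aiter_enabled) and is upgraded to SWIZZLE_AITER when the
    # weight shape allows the preshuffled (bpreshuffle) kernel.
    if not aiter_enabled:
        return AITEROpMode.NO_AITER
    if use_swizzle(*weight_shape, layout):
        return AITEROpMode.SWIZZLE_AITER
    return AITEROpMode.AITER


print(select_mode(True, (7168, 2048)))   # AITEROpMode.SWIZZLE_AITER
print(select_mode(True, (7168, 2050)))   # AITEROpMode.AITER (2050 % 32 != 0)
print(select_mode(False, (7168, 2048)))  # AITEROpMode.NO_AITER
```

Per the diff, mode 2 also implies the weight was run through aiter's `shuffle_weight` at load time and that the per-1x128 activation scales are produced with `transpose_scale=True` before calling the bpreshuffle GEMM.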