46 changes: 45 additions & 1 deletion vllm/_aiter_ops.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from typing import Optional

import torch
@@ -8,6 +9,16 @@
from vllm.utils import direct_register_custom_op


class AITEROpMode(IntEnum):
NO_AITER = 0
AITER = 1
SWIZZLE_AITER = 2


def use_swizzle(n: int, k: int, layout: tuple[int, int]) -> bool:
return n % layout[0] == 0 and k % (layout[1] * 2) == 0


def use_swizzle_gemm(n: int, k: int, dtype: torch.dtype) -> bool:

multiple_of: int = 64
@@ -84,4 +95,37 @@ def rocm_aiter_tuned_gemm(
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
)

@staticmethod
def gemm_a8w8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
import aiter as rocm_aiter

return rocm_aiter.gemm_a8w8_blockscale(A,
B,
As,
Bs,
dtype=output_dtype)

@staticmethod
def gemm_a8w8_blockscale_bpreshuffle(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
import aiter as rocm_aiter
return rocm_aiter.gemm_a8w8_blockscale_bpreshuffle(A,
B,
As,
Bs,
dtype=output_dtype)
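
Note on the new helpers in vllm/_aiter_ops.py: AITEROpMode encodes the dispatch state (0 = no AITER, 1 = plain AITER block-scale GEMM, 2 = AITER GEMM on pre-shuffled weights), and use_swizzle(n, k, layout) gates the pre-shuffle path on the weight shape. With the (16, 16) layout used in fp8.py below, a weight qualifies when n is a multiple of 16 and k is a multiple of 32. A minimal sketch of how the two compose (illustrative only; the shape is hypothetical):

    from vllm._aiter_ops import AITEROpMode, use_swizzle

    layout = (16, 16)
    n, k = 4096, 6144  # hypothetical (N, K) weight shape
    mode = (AITEROpMode.SWIZZLE_AITER
            if use_swizzle(n, k, layout) else AITEROpMode.AITER)
    # 4096 % 16 == 0 and 6144 % (16 * 2) == 0, so mode is SWIZZLE_AITER
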
20 changes: 16 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -11,6 +11,7 @@
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import AITEROpMode
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -232,10 +233,10 @@ def __init__(self, quant_config: Fp8Config):

# AITER is only supported on ROCm, only for FP8_FNUZ, and
# at the moment only on MI300-series GPUs
self.use_aiter_and_is_supported = (current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())
self.use_aiter_and_is_supported: int = int(
current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())

self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
@@ -387,6 +388,17 @@ def process_weights_after_loading(self, layer: Module) -> None:

weight = self._maybe_pad_weight(weight)

if self.use_aiter_and_is_supported:
from vllm._aiter_ops import use_swizzle
layout = (16, 16)
if use_swizzle(*weight.shape, layout):
self.use_aiter_and_is_supported = \
AITEROpMode.SWIZZLE_AITER.value
if (self.use_aiter_and_is_supported == \
AITEROpMode.SWIZZLE_AITER.value):
from aiter.ops.shuffle import shuffle_weight
weight = shuffle_weight(weight, layout=(16, 16))

# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
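Note on the fp8.py change: use_aiter_and_is_supported is now an int carrying AITEROpMode values instead of a bool. __init__ sets it to 0 or 1, and process_weights_after_loading upgrades it to SWIZZLE_AITER (2) when the weight shape passes use_swizzle for the (16, 16) layout, pre-shuffling the weight once at load time so the runtime GEMM can take the bpreshuffle kernel. A hedged sketch of that decision as a standalone helper (_maybe_swizzle is a hypothetical name; it assumes layer.weight is stored as (N, K) and that the flag is 0 or 1 when this runs):

    import torch
    from aiter.ops.shuffle import shuffle_weight
    from vllm._aiter_ops import AITEROpMode, use_swizzle

    def _maybe_swizzle(mode: int, weight: torch.Tensor) -> tuple[int, torch.Tensor]:
        # Upgrade AITER (1) to SWIZZLE_AITER (2) for shuffle-compatible weights
        # and pre-shuffle the weight once, mirroring process_weights_after_loading.
        if mode == AITEROpMode.AITER.value and use_swizzle(*weight.shape, (16, 16)):
            return AITEROpMode.SWIZZLE_AITER.value, shuffle_weight(weight, layout=(16, 16))
        return mode, weight
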
60 changes: 15 additions & 45 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -12,6 +12,7 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
@@ -53,49 +54,14 @@
scale_b=Bs.T)


def rocm_aiter_gemm_w8a8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
import aiter as rocm_aiter

return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


def rocm_aiter_gemm_w8a8_blockscale_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:

m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y


if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
if (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):

import aiter as rocm_aiter
from aiter import get_hip_quant
import aiter as rocm_aiter
from aiter import get_hip_quant

aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)


def dispatch_w8a8_blockscale_func(
Expand All @@ -110,8 +76,10 @@
], torch.Tensor]:
if use_cutlass:
return cutlass_scaled_mm
if (use_aiter_and_is_supported):
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
if (use_aiter_and_is_supported == 1):
return aiter_ops.gemm_a8w8_blockscale
elif (use_aiter_and_is_supported == 2):
return aiter_ops.gemm_a8w8_blockscale_bpreshuffle
return w8a8_block_fp8_matmul


@@ -125,7 +93,7 @@
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
use_aiter_and_is_supported: int = 0,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
@@ -173,7 +141,7 @@
use_cutlass = False

w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported)

Check failure on line 144 in vllm/model_executor/layers/quantization/utils/fp8_utils.py (GitHub Actions / pre-commit):
Argument 2 to "dispatch_w8a8_blockscale_func" has incompatible type "int"; expected "bool" [arg-type]
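
The pre-commit failure follows directly from the type change: apply_w8a8_block_fp8_linear now passes an int, while dispatch_w8a8_blockscale_func still annotates its second parameter as bool. One possible way to resolve it (not part of this PR) is to widen that annotation to the int-coded mode, for example:

    from typing import Callable
    import torch

    def dispatch_w8a8_blockscale_func(
        use_cutlass: bool,
        use_aiter_and_is_supported: int,  # was: bool; now carries AITEROpMode values
    ) -> Callable[..., torch.Tensor]:
        ...
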
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
Expand All @@ -183,7 +151,9 @@
else:
if use_aiter_and_is_supported:
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
input_2d.contiguous(),
quant_dtype=rocm_aiter.dtypes.fp8,
transpose_scale=use_aiter_and_is_supported == 2)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
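
End-to-end, the runtime path now quantizes the activation with AITER's per-1x128 quant, transposing the scale layout only when the weight was pre-shuffled (mode 2), and then calls the matching aiter_ops GEMM. A hedged usage sketch of the plain AITER case (mode 1); input_2d, weight and weight_scale_inv are hypothetical stand-ins for the layer's tensors, and [128, 128] is the usual FP8 weight block size:

    import aiter as rocm_aiter
    from aiter import get_hip_quant
    from vllm._aiter_ops import aiter_ops

    aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
    q_input, x_scale = aiter_per1x128_quant(
        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
    out = aiter_ops.gemm_a8w8_blockscale(
        q_input, weight, x_scale, weight_scale_inv,
        block_size=[128, 128], output_dtype=input_2d.dtype)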