
Commit 09ec68f

k50112113 authored and dllehr-amd committed

add triton fp8 gemm support

1 parent 93ee1c5 commit 09ec68f

File tree

2 files changed: +7 −5 lines changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 2 deletions

@@ -203,8 +203,7 @@ def __init__(self, quant_config: Fp8Config):
         # and at the moment are MI300 series
         self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                            and envs.VLLM_ROCM_USE_AITER
-                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
-                                           and current_platform.is_fp8_fnuz())
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR)

         self.block_quant = self.quant_config.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
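
For reference, a minimal sketch of the gate after this change, assuming vLLM's envs module and current_platform helper are importable. The function name use_aiter_linear is made up for illustration and is not part of this commit; in the diff this logic lives in the __init__ shown above. Note that the is_fp8_fnuz() check is no longer part of this flag and instead guards only the AITER per-token quantization branch in fp8_utils.py (second file below).

    # Minimal sketch: how the AITER linear path is gated after this commit.
    # use_aiter_linear() is a hypothetical helper for illustration only.
    import vllm.envs as envs
    from vllm.platforms import current_platform

    def use_aiter_linear() -> bool:
        # ROCm plus both AITER env switches; the fp8-fnuz hardware check was
        # dropped here and moved next to the aiter per-token quant call.
        return bool(current_platform.is_rocm()
                    and envs.VLLM_ROCM_USE_AITER
                    and envs.VLLM_ROCM_USE_AITER_LINEAR)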

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 6 additions & 3 deletions

@@ -54,9 +54,12 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
+    # import aiter as rocm_aiter

-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    # return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


 def rocm_aiter_gemm_w8a8_blockscale_fake(

@@ -185,7 +188,7 @@ def apply_w8a8_block_fp8_linear(
                                       block_size, input.dtype)

     else:
-        if use_aiter_and_is_supported:
+        if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
             q_input, x_scale = aiter_per1x128_quant(
                 input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
         else:
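
As a rough usage sketch (not part of the commit), the Triton-backed kernel this patch switches to can be called directly. It assumes a ROCm build with the aiter package installed; the shape comments are assumptions for illustration, not facts taken from this diff, and the wrapper name fp8_blockscale_gemm is hypothetical.

    # Minimal sketch of the Triton FP8 block-scaled GEMM path added here.
    # Assumes ROCm + aiter are available; shape comments are illustrative.
    import torch
    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

    def fp8_blockscale_gemm(A: torch.Tensor,   # fp8 activations
                            B: torch.Tensor,   # fp8 weights
                            As: torch.Tensor,  # activation block scales
                            Bs: torch.Tensor,  # weight block scales
                            out_dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Same call the patched rocm_aiter_gemm_w8a8_blockscale_impl now makes,
        # using the Triton kernel instead of `import aiter as rocm_aiter`.
        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=out_dtype)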
