From 88f141eb8f889454e735cc033fd43f5448521774 Mon Sep 17 00:00:00 2001
From: Doug Lehr
Date: Sat, 6 Sep 2025 12:05:08 -0500
Subject: [PATCH 1/3] Attempt to put ck blockscale back in for mi300

---
 .../layers/quantization/utils/fp8_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index e796f6729018..bdcc34065b21 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -54,11 +54,11 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    # import aiter as rocm_aiter
-
-    # return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
+    # MI300's fp8 fnuz format is enough to detect whether to call CK or Triton
+    if current_platform.is_fp8_fnuz():
+        from aiter import gemm_a8w8_blockscale
+    else:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

From 2e0dcd869a4028aaad79979f5d91a1d51382e7bd Mon Sep 17 00:00:00 2001
From: zhuyuhua-v
Date: Mon, 8 Sep 2025 15:03:43 +0800
Subject: [PATCH 2/3] add fp8 gemm path choice for rocm_aiter_gemm_w8a8_blockscale

Signed-off-by: zhuyuhua-v
---
 vllm/envs.py                                               | 7 +++++++
 vllm/model_executor/layers/quantization/utils/fp8_utils.py | 4 ++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index aa43106ccf50..280ba75683b8 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -103,6 +103,7 @@
     VLLM_USE_AITER_TRITON_SILU_MUL: bool = False
     VLLM_TRITON_FP4_GEMM_USE_ASM: bool = False
     VLLM_USE_AITER_TRITON_ROPE: bool = False
+    VLLM_USE_AITER_TRITON_GEMM: bool = False
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -797,6 +798,12 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_USE_AITER_TRITON_ROPE", "False").lower() in
              ("true", "1")),

+    # Whether to use the aiter Triton GEMM.
+    # Disabled by default.
+    "VLLM_USE_AITER_TRITON_GEMM":
+    lambda: (os.getenv("VLLM_USE_AITER_TRITON_GEMM", "False").lower() in
+             ("true", "1")),
+
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM":
     lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index bdcc34065b21..d34059f4a55d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -55,9 +55,9 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
     # MI300's fp8 fnuz format is enough to detect whether to call CK or Triton
-    if current_platform.is_fp8_fnuz():
+    if current_platform.is_fp8_fnuz() or not envs.VLLM_USE_AITER_TRITON_GEMM:
         from aiter import gemm_a8w8_blockscale
-    else:
+    elif envs.VLLM_USE_AITER_TRITON_GEMM:
         from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

From 0755e0cb4d71a9f1a48ab78f7a09e9efcf435f96 Mon Sep 17 00:00:00 2001
From: zhuyuhua-v
Date: Tue, 9 Sep 2025 09:44:02 +0800
Subject: [PATCH 3/3] refresh env VLLM_USE_AITER_TRITON_GEMM

Signed-off-by: zhuyuhua-v
---
 vllm/model_executor/layers/utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 097f0346adc5..bd9fe4e7ba71 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -14,9 +14,6 @@
 if current_platform.is_rocm():
     from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

-VLLM_USE_AITER_TRITON_GEMM = (os.getenv("VLLM_USE_AITER_TRITON_GEMM",
-                                        "False").lower() in ("true", "1"))
-

 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
@@ -121,7 +118,7 @@ def rocm_unquantized_gemm_impl(
         x.dtype in [torch.float16, torch.bfloat16] \
         and k % 8 == 0 and bias is None)

-    if VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
+    if envs.VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
         return gemm_a16w16(x, weight, bias)

     if use_skinny is not True:
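
Reviewer note: for reference, below is a minimal sketch of what the blockscale
dispatch in fp8_utils.py looks like once the full series is applied. The
parameter names A, B, As, Bs come from the call site in the hunks; their
torch.Tensor annotations and the top-level imports are assumptions added to
make the sketch self-contained, not lines from the patches.

    import torch

    import vllm.envs as envs
    from vllm.platforms import current_platform


    def rocm_aiter_gemm_w8a8_blockscale_impl(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        # MI300's fp8 fnuz format is enough to detect whether to call CK or
        # Triton: fnuz hardware always takes the CK kernel, and so does any
        # platform where VLLM_USE_AITER_TRITON_GEMM is left unset.
        if current_platform.is_fp8_fnuz() or not envs.VLLM_USE_AITER_TRITON_GEMM:
            # CK kernel
            from aiter import gemm_a8w8_blockscale
        elif envs.VLLM_USE_AITER_TRITON_GEMM:
            # Triton kernel. This branch is only reachable when the flag is
            # set (the first condition already handled the unset case), so
            # the elif condition behaves like a plain `else`.
            from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

Because the first condition short-circuits on the negated flag, every input
reaches one of the two imports and gemm_a8w8_blockscale is always bound before
the return.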
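The new flag follows the existing boolean-parsing convention in vllm/envs.py.
A standalone toy reproduction of that convention, runnable without vLLM
installed (the _env_flag helper is hypothetical, written only to mirror the
lambda in the envs.py hunk):

    import os

    def _env_flag(name: str, default: str = "False") -> bool:
        # Same parsing as the envs.py lambda: lowercase the raw value and
        # accept "true" or "1" as the truthy spellings.
        return os.getenv(name, default).lower() in ("true", "1")

    os.environ.pop("VLLM_USE_AITER_TRITON_GEMM", None)
    print(_env_flag("VLLM_USE_AITER_TRITON_GEMM"))  # False: off by default
    os.environ["VLLM_USE_AITER_TRITON_GEMM"] = "1"
    print(_env_flag("VLLM_USE_AITER_TRITON_GEMM"))  # True

PATCH 3/3 then drops the module-level copy of this read in
vllm/model_executor/layers/utils.py in favor of envs.VLLM_USE_AITER_TRITON_GEMM,
so the flag is resolved through vllm.envs rather than frozen into a constant
when the module is first imported.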