diff --git a/vllm/envs.py b/vllm/envs.py
index aa43106ccf50..280ba75683b8 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -103,6 +103,7 @@
     VLLM_USE_AITER_TRITON_SILU_MUL: bool = False
     VLLM_TRITON_FP4_GEMM_USE_ASM: bool = False
     VLLM_USE_AITER_TRITON_ROPE: bool = False
+    VLLM_USE_AITER_TRITON_GEMM: bool = False
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -797,6 +798,12 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_USE_AITER_TRITON_ROPE", "False").lower() in
              ("true", "1")),

+    # Whether to use aiter triton gemm.
+    # By default is disabled.
+    "VLLM_USE_AITER_TRITON_GEMM":
+    lambda: (os.getenv("VLLM_USE_AITER_TRITON_GEMM", "False").lower() in
+             ("true", "1")),
+
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM":
     lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index e796f6729018..d34059f4a55d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -54,11 +54,11 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    # import aiter as rocm_aiter
-
-    # return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
+    # MI300's fp8nuz should be enough to detect if we call ck vs triton
+    if current_platform.is_fp8_fnuz() or not envs.VLLM_USE_AITER_TRITON_GEMM:
+        from aiter import gemm_a8w8_blockscale
+    elif envs.VLLM_USE_AITER_TRITON_GEMM:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 097f0346adc5..bd9fe4e7ba71 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -14,9 +14,6 @@
 if current_platform.is_rocm():
     from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

-VLLM_USE_AITER_TRITON_GEMM = (os.getenv("VLLM_USE_AITER_TRITON_GEMM",
-                                        "False").lower() in ("true", "1"))
-

 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
@@ -121,7 +118,7 @@ def rocm_unquantized_gemm_impl(
         x.dtype in [torch.float16, torch.bfloat16] \
         and k % 8 == 0 and bias is None)

-    if VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
+    if envs.VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
         return gemm_a16w16(x, weight, bias)

     if use_skinny is not True:
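
For reference, a minimal standalone sketch (not the PR's code) of the backend selection the fp8_utils.py hunk introduces; is_fp8_fnuz and use_triton_flag are hypothetical stand-ins for current_platform.is_fp8_fnuz() and envs.VLLM_USE_AITER_TRITON_GEMM:

# Dispatch rule from the hunk above: fp8-fnuz platforms (MI300-class) always
# take the CK-backed aiter path; the Triton kernel is used only when the
# platform is not fnuz AND VLLM_USE_AITER_TRITON_GEMM is enabled.
def pick_blockscale_backend(is_fp8_fnuz: bool, use_triton_flag: bool) -> str:
    if is_fp8_fnuz or not use_triton_flag:
        return "ck"      # from aiter import gemm_a8w8_blockscale
    return "triton"      # from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

assert pick_blockscale_backend(True, True) == "ck"      # fnuz wins even with the flag set
assert pick_blockscale_backend(False, True) == "triton"
assert pick_blockscale_backend(False, False) == "ck"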
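The new flag follows the existing envs.py convention: it defaults to off, and only the strings "true" and "1" (case-insensitive) enable it. A hypothetical invocation to exercise the Triton GEMM path, assuming a ROCm build with aiter installed (<model> is a placeholder):

VLLM_USE_AITER_TRITON_GEMM=1 vllm serve <model>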