7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -103,6 +103,7 @@
VLLM_USE_AITER_TRITON_SILU_MUL: bool = False
VLLM_TRITON_FP4_GEMM_USE_ASM: bool = False
VLLM_USE_AITER_TRITON_ROPE: bool = False
+VLLM_USE_AITER_TRITON_GEMM: bool = False
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -797,6 +798,12 @@
lambda: (os.getenv("VLLM_USE_AITER_TRITON_ROPE", "False").lower() in
("true", "1")),

# Whether to use aiter triton gemm.
# By default is disabled.
"VLLM_USE_AITER_TRITON_GEMM":
lambda: (os.getenv("VLLM_USE_AITER_TRITON_GEMM", "False").lower() in
("true", "1")),

# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -1179,7 +1186,7 @@

    # Use AITER Triton fused rope + zeros + reshape_and_cache
    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "0"))),

Check failure on line 1189 in vllm/envs.py (GitHub Actions / pre-commit, Ruff E501):
vllm/envs.py:1189:81: E501 Line too long (94 > 80)
}

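The Ruff failure above is purely a formatting issue on the flagged line. One way that line could be wrapped to satisfy the 80-column limit (a sketch, not necessarily how the PR resolved it):

    # vllm/envs.py:1189, split so no line exceeds 80 columns:
    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
    lambda: bool(int(os.getenv(
        "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "0"))),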
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -54,11 +54,11 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
-    # import aiter as rocm_aiter
-
-    # return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
-    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
+    # MI300's fp8nuz should be enough to detect if we call ck vs triton
+    if current_platform.is_fp8_fnuz() or not envs.VLLM_USE_AITER_TRITON_GEMM:
+        from aiter import gemm_a8w8_blockscale
+    elif envs.VLLM_USE_AITER_TRITON_GEMM:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


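The new branch routes MI300-class hardware (fp8 fnuz) and the flag-off case to the CK-backed aiter.gemm_a8w8_blockscale, and only picks the Triton kernel on non-fnuz platforms with VLLM_USE_AITER_TRITON_GEMM enabled; note the elif condition is always true when reached, so it acts as a plain else. A minimal pure-Python sketch of the same predicate (no aiter required; the function name here is hypothetical):

def picks_triton_blockscale(is_fp8_fnuz: bool, use_triton_flag: bool) -> bool:
    # Mirrors the import branch in rocm_aiter_gemm_w8a8_blockscale_impl:
    # the first condition routes to the CK kernel, everything else to Triton.
    if is_fp8_fnuz or not use_triton_flag:
        return False  # from aiter import gemm_a8w8_blockscale (CK)
    return True  # from aiter.ops.triton.gemm_a8w8_blockscale import ...

assert picks_triton_blockscale(True, True) is False    # MI300: CK wins
assert picks_triton_blockscale(False, False) is False  # flag off: CK
assert picks_triton_blockscale(False, True) is True    # non-fnuz + flag: Triton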
5 changes: 1 addition & 4 deletions vllm/model_executor/layers/utils.py
@@ -14,9 +14,6 @@
if current_platform.is_rocm():
    from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

-VLLM_USE_AITER_TRITON_GEMM = (os.getenv("VLLM_USE_AITER_TRITON_GEMM",
-                                        "False").lower() in ("true", "1"))
-

def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    # Shuffle weight along the last dimension so that
@@ -121,7 +118,7 @@ def rocm_unquantized_gemm_impl(
        x.dtype in [torch.float16, torch.bfloat16] \
        and k % 8 == 0 and bias is None)

-    if VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
+    if envs.VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
        return gemm_a16w16(x, weight, bias)

    if use_skinny is not True:
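With the module-level os.getenv read removed, rocm_unquantized_gemm_impl consults envs.VLLM_USE_AITER_TRITON_GEMM on every call, so the flag no longer freezes at import time. A hedged sketch of the resulting gate (would_use_aiter_triton_gemm is a hypothetical helper; only the dtype/k/bias conditions visible in the diff are reproduced, and the full aiter_GEMM_check may test more):

import torch

def would_use_aiter_triton_gemm(use_triton_flag: bool, x: torch.Tensor,
                                k: int, bias) -> bool:
    # Same gate as the diff: flag on, fp16/bf16 input, k % 8 == 0, no bias.
    dtype_ok = x.dtype in (torch.float16, torch.bfloat16)
    return use_triton_flag and dtype_ok and k % 8 == 0 and bias is None

x = torch.empty(4, 4096, dtype=torch.bfloat16)
assert would_use_aiter_triton_gemm(True, x, k=4096, bias=None)
assert not would_use_aiter_triton_gemm(False, x, k=4096, bias=None)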