Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion vllm/compilation/activation_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,32 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
import vllm.envs as envs

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


# Default fused SiLU-and-mul + quant op used by the fusion pattern.
# NOTE: reassigned further down to the AITER-backed op when ROCm AITER
# is enabled, so the replacement pattern picks the right backend.
FUSED_OP = torch.ops._C.silu_and_mul_quant.default


def is_rocm_aiter_enabled() -> bool:
    """Return whether AITER kernels should be used.

    True only when running on the ROCm platform *and* the
    ``VLLM_ROCM_USE_AITER`` environment flag is set.
    """
    # Guard clause: AITER is ROCm-only, so bail out early elsewhere.
    if not current_platform.is_rocm():
        return False
    return envs.VLLM_ROCM_USE_AITER


if is_rocm_aiter_enabled():
    # Import lazily inside the ROCm-only branch: `aiter` is an optional
    # ROCm dependency and must not be imported on other platforms.
    # (Fixes Ruff F821: `aiter` was previously used without any import.)
    import aiter

    # Register a custom op in the `_C` namespace that dispatches to the
    # AITER fused kernel, mirroring the schema of silu_and_mul_quant.
    lib = torch.library.Library("_C", "FRAGMENT")
    # Split across two literals to stay within the line-length limit (E501).
    lib.define("aiter_silu_and_mul_quant(Tensor! result, Tensor input, "
               "Tensor scale) -> ()")

    def aiter_silu_and_mul_quant(result: torch.Tensor,
                                 input: torch.Tensor,
                                 scale: torch.Tensor) -> None:
        """Write the scaled SiLU-and-mul of `input` into `result` in place
        using the AITER kernel; `scale` is the quantization scale."""
        aiter.scaled_silu_and_mul(result, input, scale)

    # "CUDA" is the dispatch key ROCm builds of PyTorch use as well.
    lib.impl("aiter_silu_and_mul_quant", aiter_silu_and_mul_quant, "CUDA")
    # Route the fusion pass to the AITER-backed op on ROCm.
    FUSED_OP = torch.ops._C.aiter_silu_and_mul_quant.default


def silu_mul_pattern_static(result: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor):
Expand All @@ -31,7 +51,7 @@
def silu_mul_replacement_static(result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
at = auto_functionalized(FUSED_OP,
result=result,
input=input,
scale=scale)
Expand Down
Loading