diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index ce4e50a2b02d..58fe82c73cd8 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -9,12 +9,38 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+import vllm.envs as envs
 
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
 
+# Default fused op: the built-in CUDA silu_and_mul_quant kernel.
+FUSED_OP = torch.ops._C.silu_and_mul_quant.default
+
+
+def is_rocm_aiter_enabled() -> bool:
+    return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+
+
+if is_rocm_aiter_enabled():
+    import aiter
+
+    # Register the AITER kernel as a mutable custom op so the pattern
+    # matcher can rewrite to it through auto_functionalized.
+    lib = torch.library.Library("_C", "FRAGMENT")
+    lib.define("aiter_silu_and_mul_quant(Tensor! result, Tensor input, "
+               "Tensor scale) -> ()")
+
+    def aiter_silu_and_mul_quant(result: torch.Tensor, input: torch.Tensor,
+                                 scale: torch.Tensor) -> None:
+        aiter.scaled_silu_and_mul(result, input, scale)
+
+    # ROCm kernels are registered under the CUDA dispatch key.
+    lib.impl("aiter_silu_and_mul_quant", aiter_silu_and_mul_quant, "CUDA")
+    FUSED_OP = torch.ops._C.aiter_silu_and_mul_quant.default
+
 
 def silu_mul_pattern_static(result: torch.Tensor,
                             result_silu_mul: torch.Tensor,
                             input: torch.Tensor, scale: torch.Tensor):
@@ -31,7 +57,7 @@ def silu_mul_pattern_static(result: torch.Tensor,
 def silu_mul_replacement_static(result: torch.Tensor,
                                 result_silu_mul: torch.Tensor,
                                 input: torch.Tensor, scale: torch.Tensor):
-    at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
+    at = auto_functionalized(FUSED_OP,
                              result=result, input=input, scale=scale)
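
For reviewers who want to sanity-check the registration outside the fusion pass, here is a minimal smoke-test sketch (not part of the patch). It assumes a ROCm build of vLLM with AITER installed and `VLLM_ROCM_USE_AITER=1`, so that importing the module registers `_C.aiter_silu_and_mul_quant`; the fp8 output dtype and the scalar `scale` shape are assumptions modeled on how the existing CUDA `silu_and_mul_quant` op is invoked, not values taken from this diff.

```python
# Smoke-test sketch: call the registered op directly to verify it
# dispatches to aiter.scaled_silu_and_mul. Dtypes/shapes are assumptions.
import torch

# Importing the pass module runs the guarded registration block above.
import vllm.compilation.activation_quant_fusion  # noqa: F401

d = 128
# silu_and_mul convention: input packs [gate, up] along the last dim,
# so the input is twice as wide as the result.
input = torch.randn(16, 2 * d, dtype=torch.bfloat16, device="cuda")
# float8_e4m3fnuz is the fp8 format used on ROCm (assumed output dtype).
result = torch.empty(16, d, dtype=torch.float8_e4m3fnuz, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")

# The op mutates `result` in place and returns nothing, matching the
# "(Tensor! result, Tensor input, Tensor scale) -> ()" schema.
torch.ops._C.aiter_silu_and_mul_quant(result=result, input=input, scale=scale)
print(result.float().abs().max())
```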