Binary file modified docs/assets/contributing/dockerfile-stages-dependency.png
67 changes: 66 additions & 1 deletion tests/compile/test_silu_mul_quant_fusion.py
@@ -17,6 +17,12 @@
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.rocm_aiter_fusion import (
AITER_PER_TOKEN_QUANT_OP,
FUSED_SILU_MUL_PER_TOKEN_QUANT_OP,
VLLM_PER_TOKEN_QUANT_OP,
RocmAiterSiluMulFp8PerTokenQuantFusionPass,
)
from vllm.config import (
CompilationConfig,
CompilationMode,
@@ -161,8 +167,59 @@ def ops_in_model_after(self):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]


class TestSiluMulPerTokenQuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.hidden_size = hidden_size
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=GroupShape.PER_TOKEN,
pad_output=True,
)

self.use_aiter_quant = (
self.fp8_linear.quant_fp8.use_aiter
if hasattr(self.fp8_linear.quant_fp8, "use_aiter")
else False
)

weight_bf16 = torch.randn(hidden_size, hidden_size, dtype=torch.bfloat16)
weight_absmax = torch.max(torch.abs(weight_bf16), dim=0, keepdim=True)[
0
] # [1, hidden_size]
fp8_max = torch.finfo(FP8_DTYPE).max
self.wscale = (
(weight_absmax / fp8_max).clamp(min=1e-12).to(torch.float32).t()
) # [hidden_size, 1]
self.w = (weight_bf16 / weight_absmax).to(FP8_DTYPE).t()

def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, out_dtype=torch.bfloat16)
return x2, None # mimic 2-element output

def ops_in_model_before(self):
silu_mul_op = (
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul
)

quant_op = (
AITER_PER_TOKEN_QUANT_OP
if self.use_aiter_quant
else VLLM_PER_TOKEN_QUANT_OP
)

return [silu_mul_op, quant_op]

def ops_in_model_after(self):
return [FUSED_SILU_MUL_PER_TOKEN_QUANT_OP]


@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256, 4096])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
@@ -171,6 +228,7 @@ def ops_in_model_after(self):
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
(TestSiluMulPerTokenQuantModel, False, False),
],
)
# cuda_force_torch used to test torch code path on platforms that
@@ -186,6 +244,7 @@ def test_fusion_silu_and_mul_quant(
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
| TestSiluMulPerTokenQuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
@@ -195,6 +254,8 @@ def test_fusion_silu_and_mul_quant(
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
if model_class is TestSiluMulPerTokenQuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@@ -224,6 +285,8 @@ def test_fusion_silu_and_mul_quant(
)

fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if model_class is TestSiluMulPerTokenQuantModel:
fusion_passes += [RocmAiterSiluMulFp8PerTokenQuantFusionPass(config)]

passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
@@ -246,6 +309,8 @@ def test_fusion_silu_and_mul_quant(
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
atol, rtol = 5e-2, 5e-2
elif model_class == TestSiluMulPerTokenQuantModel:
atol, rtol = 1e-2, 1e-2

torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
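Note (not part of the diff): a pure-PyTorch sketch of what the new TestSiluMulPerTokenQuantModel exercises and what the fused kernel is expected to compute, assuming standard dynamic per-token FP8 quantization (row-wise absmax scaling); the function name and FP8 dtype here are illustrative.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed; vLLM picks the platform-specific FP8 dtype

def silu_mul_per_token_quant_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # SiluAndMul: gate = x[..., :d], up = x[..., d:]
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    # Dynamic per-token quant: one float32 scale per row
    fp8_max = torch.finfo(FP8_DTYPE).max
    scales = (y.abs().amax(dim=-1, keepdim=True).float() / fp8_max).clamp(min=1e-12)
    out = (y / scales).clamp(-fp8_max, fp8_max).to(FP8_DTYPE)
    return out, scales  # out: (num_tokens, d) fp8, scales: (num_tokens, 1) fp32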
40 changes: 40 additions & 0 deletions vllm/_aiter_ops.py
@@ -663,6 +663,20 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
return x_fp8, out_bs


def _rocm_aiter_fused_silu_mul_per_token_quant_impl(
out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor
) -> None:
from aiter.ops.activation import fused_silu_mul_per_token_quant

fused_silu_mul_per_token_quant(out, scales, input)


def _rocm_aiter_fused_silu_mul_per_token_quant_fake(
out: torch.Tensor, scales: torch.Tensor, input: torch.Tensor
) -> None:
pass


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False

@@ -901,6 +915,14 @@ def register_ops_once() -> None:
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_fused_silu_mul_per_token_quant",
op_func=_rocm_aiter_fused_silu_mul_per_token_quant_impl,
mutates_args=["out", "scales"],
fake_impl=_rocm_aiter_fused_silu_mul_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)

_OPS_REGISTERED = True

@staticmethod
@@ -1125,6 +1147,24 @@ def per_token_quant(
torch.ops.vllm.rocm_aiter_per_token_quant(out, x, scale)
return out, scale

@staticmethod
def fused_silu_mul_per_token_quant(
input: torch.Tensor,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
assert quant_dtype in [torch.int8, _FP8_DTYPE]
assert input.ndim == 2, "Input must be 2D tensor (num_tokens, 2*d)"
assert input.shape[-1] % 2 == 0, "Input last dimension must be even"

num_tokens, input_dim = input.shape
d = input_dim // 2

out = torch.empty((num_tokens, d), dtype=quant_dtype, device=input.device)
scales = torch.empty((num_tokens, 1), dtype=torch.float32, device=input.device)

torch.ops.vllm.rocm_aiter_fused_silu_mul_per_token_quant(out, scales, input)
return out, scales

@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,
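Note (not part of the diff): a minimal usage sketch of the new wrapper, assuming a ROCm build with AITER installed and that the static method lives on the rocm_aiter_ops class, as the pass_manager.py usage suggests.

import torch

from vllm._aiter_ops import rocm_aiter_ops  # class name assumed from pass_manager.py

# (num_tokens, 2*d) gate/up activations, as fed to SiluAndMul
x = torch.randn(32, 2 * 4096, dtype=torch.bfloat16, device="cuda")

# One kernel launch: SiLU(gate) * up, then dynamic per-token FP8 quant
# (torch.float8_e4m3fn assumed; some ROCm archs use float8_e4m3fnuz)
out, scales = rocm_aiter_ops.fused_silu_mul_per_token_quant(x, torch.float8_e4m3fn)
assert out.shape == (32, 4096) and scales.shape == (32, 1)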
2 changes: 2 additions & 0 deletions vllm/compilation/pass_manager.py
@@ -18,6 +18,7 @@
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterSiluMulFp8PerTokenQuantFusionPass,
)

if current_platform.is_cuda_alike():
@@ -132,6 +133,7 @@ def configure(self, config: VllmConfig):
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
self.passes += [RocmAiterSiluMulFp8PerTokenQuantFusionPass(config)]

# ROCm AITER all-reduce + RMSNorm fusion
if (
112 changes: 110 additions & 2 deletions vllm/compilation/rocm_aiter_fusion.py
@@ -33,7 +33,15 @@
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
FUSED_SILU_MUL_GROUP_QUANT_OP = (
torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
)
FUSED_SILU_MUL_PER_TOKEN_QUANT_OP = (
torch.ops.vllm.rocm_aiter_fused_silu_mul_per_token_quant.default
)

AITER_PER_TOKEN_QUANT_OP = torch.ops.vllm.rocm_aiter_per_token_quant.default
VLLM_PER_TOKEN_QUANT_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default


class AiterRMSFp8GroupQuantPattern:
@@ -196,7 +204,7 @@ def pattern(
def replacement(
input: torch.Tensor,
):
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
at = FUSED_SILU_MUL_GROUP_QUANT_OP(x=input, group_size=128)
return at[0], at[1]

inputs = [
@@ -206,6 +214,74 @@ def replacement(
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


class AiterSiluMulFp8PerTokenQuantPattern(ActivationQuantPattern):
"""
This pattern fuses SiluAndMul and a dynamic per-token fp8 quant custom
op (either the AITER or the vLLM variant) into the aiter
fused_silu_mul_per_token_quant op.
"""

def __init__(self, quant_op: OpOverload):
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op

def register(self, pm_pass: PatternMatcherPass):
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def pattern(
input: torch.Tensor,
):
at1 = self.silu_and_mul_matcher(input)

d = input.shape[-1] // 2
out_shape = input.shape[:-1] + (d,)
out = torch.empty(out_shape, dtype=FP8_DTYPE, device=input.device)

scale_shape = out_shape[:-1] + (1,)
scale = torch.empty(scale_shape, dtype=torch.float32, device=input.device)

if self.quant_op == AITER_PER_TOKEN_QUANT_OP:
at2 = auto_functionalized(
self.quant_op,
out=out,
x=at1,
scale=scale,
)
return at2[1], at2[2]
else:
at2 = auto_functionalized(
self.quant_op,
result=out,
input=at1,
scale=scale,
scale_ub=None,
)
return at2[1], at2[2]

def replacement(
input: torch.Tensor,
):
d = input.shape[-1] // 2
out_shape = input.shape[:-1] + (d,)
out = torch.empty(out_shape, dtype=FP8_DTYPE, device=input.device)

scale_shape = out_shape[:-1] + (1,)
scales = torch.empty(scale_shape, dtype=torch.float32, device=input.device)

at = auto_functionalized(
FUSED_SILU_MUL_PER_TOKEN_QUANT_OP,
out=out,
scales=scales,
input=input,
)
return at[1], at[2]

inputs = [
self.silu_and_mul_matcher.inputs()[0],
]

pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
@@ -240,3 +316,35 @@ def uuid(self):
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)


class RocmAiterSiluMulFp8PerTokenQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses SiLUAndMul with per-token FP8 quantization.
"""

@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)

self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_per_token_quant_fusion_pass"
)

# Register patterns for both aiter and vllm per-token quant ops
for quant_op in [AITER_PER_TOKEN_QUANT_OP, VLLM_PER_TOKEN_QUANT_OP]:
AiterSiluMulFp8PerTokenQuantPattern(quant_op).register(self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

def uuid(self):
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8PerTokenQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
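Note (not part of the diff): the new pass differs from the existing group-quant fusion only in scale granularity; a quick, illustrative shape comparison with absmax standing in for the actual quantization math.

import torch

y = torch.randn(32, 4096)  # (num_tokens, d) SiLU-mul output

# Per-token (this pass): one scale per row -> (32, 1)
per_token_scales = y.abs().amax(dim=-1, keepdim=True)

# Group quant (existing pass, group_size=128): one scale per 128-wide group -> (32, 32)
group_scales = y.view(32, 4096 // 128, 128).abs().amax(dim=-1)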