From 4ee35bb2fb7c2cd14b8f518ffe4d3b167ad1ff07 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Wed, 24 Sep 2025 01:13:38 +0800
Subject: [PATCH 01/12] Update activation_quant_fusion.py

---
 vllm/compilation/activation_quant_fusion.py | 41 +++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index ce4e50a2b02d..2223fe8af5d0 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -38,6 +38,32 @@ def silu_mul_replacement_static(result: torch.Tensor,
     return at[1]
 
 
+def silu_mul_mxfp4_gemm_pattern(result: torch.Tensor,
+                                result_silu_mul: torch.Tensor,
+                                input: torch.Tensor, weight: torch.Tensor,
+                                scale: torch.Tensor):
+    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
+                              result=result_silu_mul,
+                              input=input)
+    at2 = auto_functionalized(torch.ops.vllm.gemm_with_dynamic_quant.default,
+                              result=result,
+                              x=at1[1],
+                              weight=weight,
+                              weight_scale=scale,
+                              x_scales=None)
+    return at2[1]
+
+
+def silu_mul_mxfp4_gemm_replacement(result: torch.Tensor,
+                                    result_silu_mul: torch.Tensor,
+                                    input: torch.Tensor, weight: torch.Tensor,
+                                    scale: torch.Tensor):
+    at = auto_functionalized(torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
+                             result=result,
+                             x=input,
+                             weight=weight,
+                             weight_scale=scale)
+    return at[1]
+
+
 def empty_bf16(*args, **kwargs):
     return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
 
@@ -51,6 +77,10 @@ def empty_fp32(*args, **kwargs):
     return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
 
 
+def empty_fp4(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")
+
+
 class ActivationQuantFusionPass(VllmInductorPass):
     """
     This pass fuses a pre-defined set of custom ops into fused ops.
@@ -76,6 +106,17 @@ def __init__(self, config: VllmConfig):
         register_replacement(silu_mul_pattern_static,
                              silu_mul_replacement_static, inputs, fwd_only,
                              self.patterns)
+
+        inputs = [
+            empty_bf16(5, 4),  # result
+            empty_bf16(5, 4),  # result_silu_mul
+            empty_bf16(5, 4),  # input
+            empty_fp4(4, 8),  # weight
+            empty_fp4(1, 1),  # scale
+        ]
+        register_replacement(silu_mul_mxfp4_gemm_pattern,
+                             silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,
+                             self.patterns)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
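For orientation: the unfused graph this pattern matches corresponds, in eager mode, to a SiLU-and-mul activation followed by a dynamically quantized MXFP4 GEMM. A minimal reference for the activation half, assuming the gate/up packing that torch.ops._C.silu_and_mul uses (first half of the last dimension is the gate, second half the up-projection):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # x has shape [..., 2 * d]; computes silu(gate) * up, matching
        # the semantics of torch.ops._C.silu_and_mul.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

The fused op registered below folds the subsequent dynamic MXFP4 quantization and GEMM into the same launch path, which patch 02 implements.
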
From 22e36f5ab0381e3c4299ba23652840abbda2e227 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Wed, 24 Sep 2025 01:14:55 +0800
Subject: [PATCH 02/12] Update quark_w4a4_mxfp4.py

---
 .../quark/schemes/quark_w4a4_mxfp4.py | 64 +++++++++++++------
 1 file changed, 43 insertions(+), 21 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index 94c0698eb50c..a88b8f779b75 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -18,74 +18,94 @@
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
 
     from vllm.utils import direct_register_custom_op
 
     if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip
 
     def gemm_with_dynamic_quant(
+            result: torch.Tensor,
             x: torch.Tensor,
             weight: torch.Tensor,
             weight_scale: torch.Tensor,
             x_scales: torch.Tensor = None,
-            out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    ) -> torch.Tensor:
-        M = x.shape[0]
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
         if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
+            M = x.shape[0]
             if x_scales is None:
                 # use hip quant kernel for performance
                 x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
             else:
                 x_q = x
                 x_s = x_scales
 
             # 32 alignment is enough for dim0 padding of output for
             # gemm_a4w4 kernel
             y = torch.empty((M + 31) // 32 * 32,
                             weight.shape[0],
                             device=x_q.device,
                             dtype=out_dtype)
             gemm_a4w4(x_q,
                       weight,
                       x_s,
                       weight_scale.view(x_s.dtype),
                       y,
                       bpreshuffle=True)
-            return y[:M]
+            result.copy_(y[:M])
         else:
             if x_scales is None:
                 x_q, x_s = dynamic_mxfp4_quant(x)
             else:
                 x_q = x
                 x_s = x_scales
-            y = torch.empty(x_q.shape[0],
-                            weight.shape[0],
-                            device=x_q.device,
-                            dtype=out_dtype)
-
-            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
-            return y
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, result)
 
     def gemm_with_dynamic_quant_fake(
+            result: torch.Tensor,
             x: torch.Tensor,
             weight: torch.Tensor,
             weight_scale: torch.Tensor,
             x_scales: torch.Tensor = None,
-            out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    ) -> torch.Tensor:
-        return torch.empty((*x.shape[:-1], weight.shape[0]),
-                           dtype=out_dtype,
-                           device=x.device)
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
+        return
 
     direct_register_custom_op(
         op_name="gemm_with_dynamic_quant",
         op_func=gemm_with_dynamic_quant,
-        mutates_args=[],
+        mutates_args=['result'],
         fake_impl=gemm_with_dynamic_quant_fake,
         dispatch_key=current_platform.dispatch_key,
     )
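+    # The op now follows an out-parameter convention: callers allocate
+    # `result` with shape [*x.shape[:-1], weight.shape[0]] and dtype
+    # `out_dtype`, and the kernel writes into it in place, for example:
+    #
+    #     result = torch.empty((*x.shape[:-1], w.shape[0]),
+    #                          dtype=out_dtype, device=x.device)
+    #     torch.ops.vllm.gemm_with_dynamic_quant(result, x, w, w_scale)
+    #
+    # Declaring mutates_args=['result'] is what lets torch.compile wrap
+    # calls in auto_functionalized, which the fusion patterns match against.
+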
+    def silu_and_mul_mxfp4_gemm(
+            result: torch.Tensor,
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
+        x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(x, 'silu')
+        gemm_with_dynamic_quant(result, x_fp4, weight, weight_scale,
+                                blockscale_e8m0, out_dtype)
+
+    def silu_and_mul_mxfp4_gemm_fake(
+            result: torch.Tensor,
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            weight_scale: torch.Tensor,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="silu_and_mul_mxfp4_gemm",
+        op_func=silu_and_mul_mxfp4_gemm,
+        mutates_args=['result'],
+        fake_impl=silu_and_mul_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 except ImportError:
     dynamic_mxfp4_quant = gemm_afp4wfp4 = None
 
@@ -225,5 +245,7 @@ def apply_weights(self,
             return F.linear(x, dq_w, bias)
         else:
-            return torch.ops.vllm.gemm_with_dynamic_quant(
-                x, layer.weight, layer.weight_scale, x_quant_scales,
-                self.out_dtype)
+            result = torch.empty((*x.shape[:-1], layer.weight.shape[0]),
+                                 dtype=self.out_dtype, device=x.device)
+            torch.ops.vllm.gemm_with_dynamic_quant(
+                result, x, layer.weight, layer.weight_scale, x_quant_scales,
+                self.out_dtype)
+            return result

From b41c882fd50adba7f27a02468b721e79017be3b9 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Wed, 24 Sep 2025 11:34:07 +0800
Subject: [PATCH 03/12] Refine example inputs

---
 vllm/compilation/activation_quant_fusion.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 2223fe8af5d0..4dd2127d25dd 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -108,11 +108,11 @@ def __init__(self, config: VllmConfig):
                              self.patterns)
 
         inputs = [
-            empty_bf16(5, 4),  # result
-            empty_bf16(5, 4),  # result_silu_mul
-            empty_bf16(5, 4),  # input
-            empty_fp4(4, 8),  # weight
-            empty_fp4(1, 1),  # scale
+            empty_bf16(32, 32),  # result
+            empty_bf16(32, 32),  # result_silu_mul
+            empty_bf16(32, 32),  # input
+            empty_fp4(32, 32),  # weight
+            empty_fp4(32, 1),  # scale
         ]
         register_replacement(silu_mul_mxfp4_gemm_pattern,
                              silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,

From 933c6b52c19f5ed347f71bb2398522a37470c971 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 00:46:27 +0800
Subject: [PATCH 04/12] Update pass_manager.py

---
 vllm/compilation/pass_manager.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index e07e52be9fdf..d28a018140b1 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -14,6 +14,9 @@
 if current_platform.is_cuda():
     from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
+if current_platform.is_rocm():
+    from .rocm_fusion import ROCmFusionPass
+
 from .activation_quant_fusion import ActivationQuantFusionPass
 from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@@ -64,6 +67,8 @@ def configure(self, config: VllmConfig):
         if self.pass_config.enable_fusion:
             self.passes += [FusionPass.instance(config)]
             self.passes += [ActivationQuantFusionPass(config)]
+            if current_platform.is_rocm():
+                self.passes += [ROCmFusionPass(config)]
 
         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
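Both the activation fusion and the new ROCm fusion pass are gated by the same enable_fusion flag, as the hunk above shows. A minimal sketch of turning them on, assuming the PassConfig/CompilationConfig classes exported from vllm.config (field names beyond enable_fusion are not part of this series):

    from vllm.config import CompilationConfig, PassConfig

    # enable_fusion gates FusionPass, ActivationQuantFusionPass and,
    # on ROCm, the ROCmFusionPass added above.
    compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=True))
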
From a900b295a82376652eb092982815c35475769cf2 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 00:47:27 +0800
Subject: [PATCH 05/12] Update quark_w4a4_mxfp4.py

---
 .../quark/schemes/quark_w4a4_mxfp4.py | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index a88b8f779b75..8ab2aa135db3 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -15,10 +15,12 @@
 from vllm.platforms import current_platform
 
 try:
+    import triton
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
     from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
+    from aiter.ops.triton.fused_mxfp4_quant import _fused_rms_mxfp4_quant_kernel
 
     from vllm.utils import direct_register_custom_op
 
     if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
@@ -106,6 +108,69 @@ def silu_and_mul_mxfp4_gemm_fake(
         dispatch_key=current_platform.dispatch_key,
     )
 
+    def add_rmsnorm_mxfp4_gemm(
+            result: torch.Tensor, input: torch.Tensor,
+            residual_out: torch.Tensor, residual: torch.Tensor,
+            weight_rms: torch.Tensor, weight_gemm: torch.Tensor,
+            scale: torch.Tensor, epsilon: float,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
+        MXFP4_QUANT_BLOCK_SIZE = 32
+        M, N1 = input.shape
+        BLOCK_SIZE = max(triton.next_power_of_2(N1), MXFP4_QUANT_BLOCK_SIZE)
+        BLOCK_SIZE = max(BLOCK_SIZE, MXFP4_QUANT_BLOCK_SIZE)
+        res_row_stride = residual.stride(0)
+        out_res_row_stride = residual_out.stride(0)
+        rms_out_fp4 = torch.empty((M, N1 // 2),
+                                  dtype=torch.uint8,
+                                  device=input.device)
+        rms_out_bs = torch.empty(
+            ((N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE, M),
+            dtype=torch.uint8,
+            device=input.device,
+        ).T
+        _fused_rms_mxfp4_quant_kernel[(M, )](
+            input,
+            weight_rms,
+            None,
+            None,
+            residual,
+            rms_out_fp4,
+            rms_out_bs,
+            None,
+            residual_out,
+            epsilon,
+            0.0,
+            M,
+            N1,
+            0,
+            input.stride(0),
+            0,
+            res_row_stride,
+            rms_out_fp4.stride(0),
+            *rms_out_bs.stride(),
+            0,
+            out_res_row_stride,
+            BLOCK_SIZE=BLOCK_SIZE,
+            MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE,
+            SKIP_SECOND_INPUT=True,
+            FIRST_INPUT_RES=True,
+        )
+        gemm_with_dynamic_quant(result, rms_out_fp4, weight_gemm, scale,
+                                rms_out_bs, out_dtype)
+
+    def add_rmsnorm_mxfp4_gemm_fake(
+            result: torch.Tensor, input: torch.Tensor,
+            residual_out: torch.Tensor, residual: torch.Tensor,
+            weight_rms: torch.Tensor, weight_gemm: torch.Tensor,
+            scale: torch.Tensor, epsilon: float,
+            out_dtype: Optional[torch.dtype] = torch.bfloat16) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="add_rmsnorm_mxfp4_gemm",
+        op_func=add_rmsnorm_mxfp4_gemm,
+        mutates_args=['result', 'residual_out'],
+        fake_impl=add_rmsnorm_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 except ImportError:
     dynamic_mxfp4_quant = gemm_afp4wfp4 = None
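For reference, the fused kernel above computes a residual add followed by RMSNorm and emits the normalized activations directly in MXFP4. A simplified bf16 reference of the norm half, matching the semantics of torch.ops._C.fused_add_rms_norm (computation-dtype details omitted):

    import torch

    def add_rmsnorm_reference(x: torch.Tensor, residual: torch.Tensor,
                              weight: torch.Tensor, eps: float):
        # The residual connection is folded in first; the updated residual
        # is the second output, mirroring
        # mutates_args=['result', 'residual_out'] above.
        residual_out = x + residual
        variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
        normed = residual_out * torch.rsqrt(variance + eps).to(x.dtype) * weight
        return normed, residual_out
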
From 27ca9ea919d8faa14432af86df67872db13e922e Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 00:48:15 +0800
Subject: [PATCH 06/12] Create rocm_fusion.py

---
 vllm/compilation/rocm_fusion.py | 252 ++++++++++++++++++++++++++++++++
 1 file changed, 252 insertions(+)
 create mode 100644 vllm/compilation/rocm_fusion.py

diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
new file mode 100644
index 000000000000..789d08b92d46
--- /dev/null
+++ b/vllm/compilation/rocm_fusion.py
@@ -0,0 +1,252 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Callable
+
+import torch
+from torch._ops import OpOverload
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
+                                             register_replacement, Match)
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+from .fx_utils import find_getitem_maybe
+from .multi_output_match import MultiOutputMatch
+from .vllm_inductor_pass import VllmInductorPass
+
+
+logger = init_logger(__name__)
+
+
+def empty_bf16(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+
+
+def empty_fp8(*args, **kwargs):
+    fp8 = current_platform.fp8_dtype()
+    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
+
+
+def empty_fp32(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
+
+
+def empty_fp4(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")
+
+
+class QuantMultiOutputMatch(MultiOutputMatch):
+
+    def __init__(self, match: Match, fused_op):
+        super().__init__(match)
+        assert isinstance(fused_op, OpOverload)
+        self.FUSED_OP = fused_op  # in-place fused quant op
+
+    def insert_fused_node(self,
+                          fused_return_mapping: dict[int, tuple[torch.fx.Node,
+                                                                int]],
+                          **kwargs):
+        """
+        This utility function inserts an auto-functionalized node for
+        FUSED_OP. It also correctly sets its meta value and rebinds the
+        users of the unfused nodes to use the fused node instead.
+
+        :param fused_return_mapping: A dictionary, mapping from getitem
+            indices of the fused node result to a tuple of the old node
+            and a getitem index.
+        :param kwargs: kwargs that get directly forwarded to the auto_fn node
+
+        Example:
+        If we want to replace this graph:
+        _, x1, x2 = auto_fn(op1)
+        _, y1, y2 = auto_fn(op2)
+
+        with
+        _, x1, y2, x2 = auto_fn(FUSED_OP)
+
+        we would call:
+        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2),
+                           3: (op1_node, 2)})
+
+        Note that the 0th element is None for auto-functionalized in-place
+        ops. Hence, others appear 1-indexed.
+        """
+        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
+        indices = fused_return_mapping.keys()
+        getitem_nodes = self.insert_getitems(fused_node, indices)
+
+        # Prepare the meta value, use a list so it's mutable
+        meta_val = [None] * (max(indices) + 1)
+
+        # Iterate through elements of the tuple produced by fused_node
+        for idx, getitem_node in zip(indices, getitem_nodes):
+            old_node, old_idx = fused_return_mapping[idx]
+
+            # If the old value was never used, the old_getitem might not exist
+            old_getitem = find_getitem_maybe(old_node, old_idx)
+            if old_getitem is not None:
+                # Rebind the users of match getitem nodes to use the new
+                # nodes. The old nodes will be removed by DCE at the end
+                # of the pass.
+                old_getitem.replace_all_uses_with(getitem_node)
+                getitem_node.meta["val"] = old_getitem.meta["val"]
+
+            # Extract the appropriate meta value
+            # It is present even if the getitem node does not exist
+            meta_val[idx] = old_node.meta["val"][old_idx]
+
+        # Fix the meta value on the new fused node
+        fused_node.meta["val"] = tuple(meta_val)
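+
+# Convention reminder: auto_functionalized(op, ...) returns a tuple whose
+# 0th element is the op's return value (None for these in-place ops); the
+# mutated arguments follow in declaration order. For add_rmsnorm_mxfp4_gemm
+# (mutates_args=['result', 'residual_out']) that means:
+#
+#     at = auto_functionalized(FUSED_OP, result=..., residual_out=..., ...)
+#     gemm_out, new_residual = at[1], at[2]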
+
+
+ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default
+QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default
+
+
+class AddRMSNormMXFP4GemmPattern:
+
+    def __init__(self, epsilon: float):
+        self.epsilon = epsilon
+        self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default
+
+    def register(self, pm_pass: PatternMatcherPass,
+                 record_match: Callable[[MultiOutputMatch], bool]):
+
+        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
+                    input: torch.Tensor, residual_out: torch.Tensor,
+                    residual: torch.Tensor, weight_rms: torch.Tensor,
+                    weight_gemm: torch.Tensor, scale: torch.Tensor):
+            at1 = auto_functionalized(ADD_RMS_OP,
+                                      result=result_rms,
+                                      input=input,
+                                      residual_out=residual_out,
+                                      residual=residual,
+                                      weight=weight_rms,
+                                      epsilon=self.epsilon)
+            at2 = auto_functionalized(QUANT_F4GEMM_OP,
+                                      result=result,
+                                      x=at1[1],
+                                      weight=weight_gemm,
+                                      weight_scale=scale,
+                                      x_scales=None)
+            return at2[1], at1[2]
+
+        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
+                        input: torch.Tensor, residual_out: torch.Tensor,
+                        residual: torch.Tensor, weight_rms: torch.Tensor,
+                        weight_gemm: torch.Tensor, scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     residual_out=residual_out,
+                                     residual=residual,
+                                     weight_rms=weight_rms,
+                                     weight_gemm=weight_gemm,
+                                     scale=scale,
+                                     epsilon=self.epsilon)
+            return at[1], at[2]
+
+        inputs = [
+            empty_bf16(32, 32),  # result
+            empty_bf16(32, 32),  # result_rms
+            empty_bf16(32, 32),  # input
+            empty_bf16(32, 32),  # residual_out
+            empty_bf16(32, 32),  # residual
+            empty_bf16(1, 32),  # weight_rms
+            empty_fp4(32, 32),  # weight_gemm
+            empty_fp4(32, 1),  # scale
+        ]
+
+        register_replacement(pattern,
+                             replacement,
+                             inputs,
+                             fwd_only,
+                             pm_pass,
+                             extra_check=lambda m: record_match(
+                                 self.Match(m, self.FUSED_OP)))
+
+    class Match(QuantMultiOutputMatch):
+
+        def process(self):
+            # Find the nodes in the match that we need to rebind
+            add_rms_node = self.find_auto_fn(ADD_RMS_OP)
+            quant_f4gemm_node = self.find_auto_fn(QUANT_F4GEMM_OP)
+
+            assert len(add_rms_node.users) == 2
+            assert len(quant_f4gemm_node.users) == 1
+
+            # First, insert a new auto_functionalized node for the fused op,
+            # as well as getitem nodes to extract the result and residual.
+            # The auto_fn node returns a tuple of (None, result, residual).
+            #
+            # The resulting graph looks like this:
+            # at = auto_functionalized(torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default, ...)  # noqa
+            # result_node_new = at[1]
+            # residual_node_new = at[2]
+            with self.inserting_after_match():
+                # Missing epsilon, scalars cannot be inputs to the pattern
+                kwargs = self.match.kwargs.copy()
+                del kwargs["result_rms"]  # not used in the fused op
+                # 0 is always None
+                fused_return_mapping = {
+                    1: (quant_f4gemm_node, 1),
+                    2: (add_rms_node, 2)
+                }
+                self.insert_fused_node(fused_return_mapping,
+                                       **kwargs,
+                                       epsilon=add_rms_node.kwargs["epsilon"])
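+
+# The pattern above has two live outputs (the GEMM result and the updated
+# residual), which the inductor pattern matcher cannot rewrite on its own.
+# record_match therefore stashes each Match and returns False, and process()
+# later rebinds both getitem users onto the fused node, effectively turning
+#
+#     _, rms_out, residual_new = auto_fn(fused_add_rms_norm)
+#     _, gemm_out = auto_fn(gemm_with_dynamic_quant)
+# into
+#     _, gemm_out, residual_new = auto_fn(add_rmsnorm_mxfp4_gemm)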
+
+
+class ROCmFusionPass(VllmInductorPass):
+    """
+    This pass fuses a pre-defined set of custom ops into fused ops.
+    It uses the torch pattern matcher to find the patterns and replace them.
+
+    Because patterns can only be registered once, the pass is a singleton.
+    This will be addressed in a future version of PyTorch:
+    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
+    """
+
+    def __init__(self, config: VllmConfig):
+        super().__init__(config)
+
+        self.matches: list[MultiOutputMatch] = []
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="rocm_fusion_pass")
+
+        for epsilon in [1e-5, 1e-6]:
+            AddRMSNormMXFP4GemmPattern(epsilon).register(
+                self.patterns, self.record_match)
+
+    def record_match(self, match: MultiOutputMatch) -> bool:
+        # Hijack the extra_check to record the match and
+        # save it for post-processing.
+        self.matches.append(match)
+
+        # Return False to prevent automatic replacement.
+        return False
+
+    def process_matches(self, graph: torch.fx.Graph):
+        """
+        Manually process multi-output matches and replace them with fused
+        nodes. See MultiOutputMatch for more details.
+        """
+        for match in self.matches:
+            match.process()
+
+        # Finally, remove matched nodes
+        graph.eliminate_dead_code()
+        assert all(node not in graph.nodes for match in self.matches
+                   for node in match.match.nodes)
+
+    def __call__(self, graph: torch.fx.Graph):
+        logger.info(graph)
+
+        self.begin()
+        self.dump_graph(graph, "before_rocm_fusion")
+
+        count = self.patterns.apply(graph)
+        logger.info("Replaced %s patterns", count)
+        self.dump_graph(graph, "after_pattern_match")
+
+        # Manually process multi-output matches (and run DCE)
+        self.process_matches(graph)
+        logger.info("Post-processed %s matches", len(self.matches))
+        self.dump_graph(graph, "after_rocm_fusion")
+        self.matches.clear()
+        self.end_and_log()

From c191a912e21587cbb719cb1b719fb62944bccd0a Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 00:52:50 +0800
Subject: [PATCH 07/12] Update rocm_fusion.py: refine code

---
 vllm/compilation/rocm_fusion.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
index 789d08b92d46..5b62f7b102f0 100644
--- a/vllm/compilation/rocm_fusion.py
+++ b/vllm/compilation/rocm_fusion.py
@@ -235,18 +235,16 @@ def process_matches(self, graph: torch.fx.Graph):
                    for node in match.match.nodes)
 
     def __call__(self, graph: torch.fx.Graph):
-        logger.info(graph)
-
         self.begin()
         self.dump_graph(graph, "before_rocm_fusion")
 
         count = self.patterns.apply(graph)
-        logger.info("Replaced %s patterns", count)
+        logger.debug("Replaced %s patterns", count)
         self.dump_graph(graph, "after_pattern_match")
 
         # Manually process multi-output matches (and run DCE)
         self.process_matches(graph)
-        logger.info("Post-processed %s matches", len(self.matches))
+        logger.debug("Post-processed %s matches", len(self.matches))
         self.dump_graph(graph, "after_rocm_fusion")
         self.matches.clear()
         self.end_and_log()
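The bug fix in the next patch threads the quant op through the Match object so find_auto_fn can locate it on the instance. For context, a simplified sketch of what find_auto_fn (from vllm.compilation.fx_utils) does, under the assumption that it scans the matched nodes for the auto_functionalized call wrapping a given OpOverload:

    import torch
    from torch._higher_order_ops.auto_functionalize import auto_functionalized

    def find_auto_fn_sketch(nodes, op):
        # Return the auto_functionalized node whose wrapped target is `op`.
        for node in nodes:
            if (node.op == "call_function"
                    and node.target is auto_functionalized
                    and node.args[0] is op):
                return node
        return None
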
From d00ae81284dc0cf8e31e0f2d08c189fff4a036fc Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 11:42:47 +0800
Subject: [PATCH 08/12] Update rocm_fusion.py: bugfix

---
 vllm/compilation/rocm_fusion.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
index 5b62f7b102f0..659d35f60a08 100644
--- a/vllm/compilation/rocm_fusion.py
+++ b/vllm/compilation/rocm_fusion.py
@@ -40,9 +40,11 @@ def empty_fp4(*args, **kwargs):
 
 class QuantMultiOutputMatch(MultiOutputMatch):
 
-    def __init__(self, match: Match, fused_op):
+    def __init__(self, match: Match, quant_op, fused_op):
         super().__init__(match)
+        # assert isinstance(quant_op, OpOverload)
         assert isinstance(fused_op, OpOverload)
+        self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
     def insert_fused_node(self,
@@ -97,13 +99,13 @@ def insert_fused_node(self, fused_return_mapping: dict[int, tuple[torch.fx.Node,
 
 
 ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default
-QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default
 
 
 class AddRMSNormMXFP4GemmPattern:
 
     def __init__(self, epsilon: float):
         self.epsilon = epsilon
         self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default
+        self.QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default
 
     def register(self, pm_pass: PatternMatcherPass,
                  record_match: Callable[[MultiOutputMatch], bool]):
@@ -119,7 +121,7 @@ def pattern(
                                       residual=residual,
                                       weight=weight_rms,
                                       epsilon=self.epsilon)
-            at2 = auto_functionalized(QUANT_F4GEMM_OP,
+            at2 = auto_functionalized(self.QUANT_F4GEMM_OP,
                                       result=result,
                                       x=at1[1],
                                       weight=weight_gemm,
                                       weight_scale=scale,
                                       x_scales=None)
             return at2[1], at1[2]
@@ -161,14 +163,14 @@ def replacement(
                              fwd_only,
                              pm_pass,
                              extra_check=lambda m: record_match(
-                                 self.Match(m, self.FUSED_OP)))
+                                 self.Match(m, self.QUANT_F4GEMM_OP,
+                                            self.FUSED_OP)))
 
     class Match(QuantMultiOutputMatch):
 
         def process(self):
             # Find the nodes in the match that we need to rebind
             add_rms_node = self.find_auto_fn(ADD_RMS_OP)
-            quant_f4gemm_node = self.find_auto_fn(QUANT_F4GEMM_OP)
+            quant_f4gemm_node = self.find_auto_fn(self.QUANT_OP)
 
             assert len(add_rms_node.users) == 2
             assert len(quant_f4gemm_node.users) == 1
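A note on the cache-clearing workaround introduced in the next patch: register_replacement appears to deduplicate patterns by the pattern function's name, so registering the same pattern function once per epsilon value would otherwise trip the duplicate check. Clearing the module-level cache sidesteps that; a sketch of the idea:

    import torch._inductor.pattern_matcher as pm

    # After registering the pattern for one epsilon value, clear the dedup
    # cache so the same pattern function can be registered again.
    pm._seen_patterns.clear()
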
From a500bb44ddabd393ff5fe736afbc5d248e76e139 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 14:48:22 +0800
Subject: [PATCH 09/12] Update rocm_fusion.py: bugfix: do not use large
 example inputs

---
 vllm/compilation/rocm_fusion.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
index 659d35f60a08..6d8fc93e0208 100644
--- a/vllm/compilation/rocm_fusion.py
+++ b/vllm/compilation/rocm_fusion.py
@@ -42,7 +42,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
     def __init__(self, match: Match, quant_op, fused_op):
         super().__init__(match)
-        # assert isinstance(quant_op, OpOverload)
+        assert isinstance(quant_op, OpOverload)
         assert isinstance(fused_op, OpOverload)
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
@@ -146,14 +146,14 @@ def replacement(
             return at[1], at[2]
 
         inputs = [
-            empty_bf16(32, 32),  # result
-            empty_bf16(32, 32),  # result_rms
-            empty_bf16(32, 32),  # input
-            empty_bf16(32, 32),  # residual_out
-            empty_bf16(32, 32),  # residual
-            empty_bf16(1, 32),  # weight_rms
-            empty_fp4(32, 32),  # weight_gemm
-            empty_fp4(32, 1),  # scale
+            empty_bf16(32, 4),  # result
+            empty_bf16(32, 4),  # result_rms
+            empty_bf16(32, 4),  # input
+            empty_bf16(32, 4),  # residual_out
+            empty_bf16(32, 4),  # residual
+            empty_bf16(1, 32),  # weight_rms
+            empty_fp4(32, 4),  # weight_gemm
+            empty_fp4(1, 1),  # scale
         ]
 
         register_replacement(
@@ -214,6 +214,10 @@ def __init__(self, config: VllmConfig):
         for epsilon in [1e-5, 1e-6]:
             AddRMSNormMXFP4GemmPattern(epsilon).register(
                 self.patterns, self.record_match)
+
+        # WARNING: This is a hack to clear the pattern matcher cache
+        # and allow multiple values of epsilon.
+        torch._inductor.pattern_matcher._seen_patterns.clear()
 
     def record_match(self, match: MultiOutputMatch) -> bool:
         # Hijack the extra_check to record the match and

From 995289731c1e810a6b1576ca452eb59cf14d5cee Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 16:45:43 +0800
Subject: [PATCH 10/12] Update inductor_pass.py: Add fake_mode

---
 vllm/compilation/inductor_pass.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py
index 2a149c65b387..51359c96d5b1 100644
--- a/vllm/compilation/inductor_pass.py
+++ b/vllm/compilation/inductor_pass.py
@@ -1,12 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
 import hashlib
 import inspect
 import json
 import types
 from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
+from torch._subclasses.fake_tensor import (FakeTensorMode,
+                                           unset_fake_temporarily)
 import torch
 from torch import fx
@@ -114,3 +117,20 @@ def __call__(self, graph: torch.fx.Graph):
 
     def uuid(self) -> Any:
         return self._uuid
+
+
+def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Applies a FakeTensorMode context. This is useful when you don't want to
+    create or run things with real tensors.
+    """
+
+    @functools.wraps(fn)
+    def fn_new(*args, **kwargs) -> Any:
+        with torch._guards.tracing(
+                None), unset_fake_temporarily(), FakeTensorMode():
+            result = fn(*args, **kwargs)
+
+        return result
+
+    return fn_new

From b85431d6f8215280c676b9391ed7b814609eec26 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 17:02:44 +0800
Subject: [PATCH 11/12] Update rocm_fusion.py: Enable Fake

---
 vllm/compilation/rocm_fusion.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/vllm/compilation/rocm_fusion.py b/vllm/compilation/rocm_fusion.py
index 6d8fc93e0208..35f7c4800636 100644
--- a/vllm/compilation/rocm_fusion.py
+++ b/vllm/compilation/rocm_fusion.py
@@ -13,6 +13,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .fx_utils import find_getitem_maybe
 from .multi_output_match import MultiOutputMatch
 from .vllm_inductor_pass import VllmInductorPass
@@ -146,13 +147,13 @@ def replacement(
             return at[1], at[2]
 
         inputs = [
-            empty_bf16(32, 4),  # result
-            empty_bf16(32, 4),  # result_rms
-            empty_bf16(32, 4),  # input
-            empty_bf16(32, 4),  # residual_out
-            empty_bf16(32, 4),  # residual
-            empty_bf16(1, 32),  # weight_rms
-            empty_fp4(32, 4),  # weight_gemm
+            empty_bf16(4, 4),  # result
+            empty_bf16(4, 4),  # result_rms
+            empty_bf16(4, 4),  # input
+            empty_bf16(4, 4),  # residual_out
+            empty_bf16(4, 4),  # residual
+            empty_bf16(1, 4),  # weight_rms
+            empty_fp4(4, 4),  # weight_gemm
             empty_fp4(1, 1),  # scale
         ]
 
@@ -204,6 +205,7 @@ class ROCmFusionPass(VllmInductorPass):
     https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
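With @enable_fake_mode, the example tensors built in __init__ become fake tensors: they carry shape, dtype, and device metadata but allocate no device memory, which is all register_replacement needs for tracing the pattern. A small sketch of the decorator in use (the function below is illustrative, not from the patch):

    import torch
    from vllm.compilation.inductor_pass import enable_fake_mode

    @enable_fake_mode
    def make_example_inputs():
        # Returns metadata-only FakeTensors; no GPU memory is touched.
        return [torch.empty(4, 4, dtype=torch.bfloat16, device="cuda")]
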
From 1bf7f33d90b368fa93a327401d3c0a6798328007 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Thu, 25 Sep 2025 17:03:42 +0800
Subject: [PATCH 12/12] Update activation_quant_fusion.py: Enable Fake

---
 vllm/compilation/activation_quant_fusion.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 4dd2127d25dd..2d60556974f5 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -91,6 +92,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
     https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
@@ -108,11 +110,11 @@ def __init__(self, config: VllmConfig):
                              self.patterns)
 
         inputs = [
-            empty_bf16(32, 32),  # result
-            empty_bf16(32, 32),  # result_silu_mul
-            empty_bf16(32, 32),  # input
-            empty_fp4(32, 32),  # weight
-            empty_fp4(32, 1),  # scale
+            empty_bf16(5, 4),  # result
+            empty_bf16(5, 4),  # result_silu_mul
+            empty_bf16(5, 4),  # input
+            empty_fp4(5, 4),  # weight
+            empty_fp4(1, 1),  # scale
         ]
         register_replacement(silu_mul_mxfp4_gemm_pattern,
                              silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,
                              self.patterns)
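Taken together, the series exposes silu_and_mul_mxfp4_gemm as an out-parameter custom op. A hedged end-to-end sketch of calling it directly, with illustrative shapes (the weight is MXFP4-packed uint8 with two FP4 values per byte, and weight_scale holds the e8m0 block scales produced by the quantizer):

    import torch

    def fused_mlp_down_proj(x: torch.Tensor, weight: torch.Tensor,
                            weight_scale: torch.Tensor) -> torch.Tensor:
        # x: [M, 2 * K] bf16 with gate/up packed along the last dim;
        # weight: MXFP4-packed uint8 with weight.shape[0] == N.
        M, N = x.shape[0], weight.shape[0]
        result = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
        torch.ops.vllm.silu_and_mul_mxfp4_gemm(result, x, weight,
                                               weight_scale)
        return result

At compile time, ActivationQuantFusionPass rewrites the silu_and_mul + gemm_with_dynamic_quant sequence into this single op, and ROCmFusionPass does the analogous rewrite for the fused_add_rms_norm + gemm_with_dynamic_quant pair.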