Commit c719c40

[Bugfix] Defunctionalize TRTLLM AR+Norm op to avoid an extra clone kernel before it (vllm-project#29631)

Signed-off-by: elvischenv <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

1 parent: b08025a

File tree

vllm/compilation/fix_functionalization.py
vllm/compilation/fx_utils.py

2 files changed: +14 −2 lines

vllm/compilation/fix_functionalization.py

Lines changed: 12 additions & 0 deletions
@@ -103,6 +103,18 @@ def __call__(self, graph: torch.fx.Graph):
             ]:
                 mutated_args = {1: "result"}
                 self.defunctionalize(graph, node, mutated_args)
+            elif (
+                at_target
+                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
+            ):
+                mutated_args = {
+                    1: "allreduce_in",
+                    2: "residual",
+                    3: "norm_out",
+                    4: "quant_out",
+                    5: "scale_out",
+                }
+                self.defunctionalize(graph, node, mutated_args)
             # For some reason we need to specify the args for both
             # silu_and_mul and silu_and_mul_quant. The kwargs
             # pathway gets the wrong answer.
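
For context on what this mapping does: auto_functionalized(op, **kwargs) returns a tuple, and getitem(node, i) for each index in mutated_args is the functionalized copy of the kwarg it names. Defunctionalization reroutes those getitem users back to the original inputs and calls the op directly, so the in-place allreduce no longer forces the compiler to materialize a defensive clone of allreduce_in first. Below is a minimal sketch of the idea, not vLLM's actual defunctionalize implementation; it assumes the op's own return value is unused, as it is here, so every live getitem index appears in mutated_args.

import operator

import torch.fx


def defunctionalize_sketch(
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    mutated_args: dict[int, str],
) -> None:
    op = node.args[0]  # the OpOverload wrapped by auto_functionalized
    for user in list(node.users):
        # Users of an auto_functionalized node are getitem(node, i) calls.
        assert user.op == "call_function" and user.target is operator.getitem
        idx = user.args[1]
        # Point consumers of the "functional" output at the mutated input,
        # since the in-place op writes into that buffer directly.
        user.replace_all_uses_with(node.kwargs[mutated_args[idx]])
        graph.erase_node(user)
    # Replace the functional wrapper with a direct, mutating call to the op.
    with graph.inserting_before(node):
        graph.call_function(op, kwargs=dict(node.kwargs))
    graph.erase_node(node)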

vllm/compilation/fx_utils.py

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ def find_op_nodes(
         return

     assert isinstance(op, OpOverload)
-    if not op._schema.is_mutable:
-        yield from graph.find_nodes(op="call_function", target=op)
+
+    yield from graph.find_nodes(op="call_function", target=op)

     for n in graph.find_nodes(op="call_function", target=auto_functionalized):
         if n.args[0] == op:
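
Dropping the is_mutable guard matters because, once the pass above rewrites a mutable op, that op appears in the graph as a direct call_function node rather than inside auto_functionalized, and find_op_nodes must now report both forms. A hedged sketch of that combined lookup (all_call_sites is an illustrative helper, not part of vLLM):

from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch.fx import Graph, Node


def all_call_sites(graph: Graph, op) -> list[Node]:
    # Direct calls: how a mutable op appears after defunctionalization.
    # The removed `if not op._schema.is_mutable` guard used to skip these.
    direct = list(graph.find_nodes(op="call_function", target=op))
    # Functionalized calls: auto_functionalized(op, **kwargs) wrapper nodes,
    # matched the same way the loop below the changed lines does.
    wrapped = [
        n
        for n in graph.find_nodes(op="call_function", target=auto_functionalized)
        if n.args[0] == op
    ]
    return direct + wrapped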
