
Commit ed4aa44

tianrengao authored and pytorchmergebot committed
CustomOp Inline Fusion (pytorch#165952)
Add Inline Fusion Support for Custom Op Autotuning
---------------------------------------------------

This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph so that it can fuse with surrounding operations.

### Usage

```python
def decompose_k_implementation(
    a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
    ...


@torch.library.custom_op("my_lib::matmul_relu", mutates_args={})
def custom_matmul_relu_dk(
    a: torch.Tensor, b: torch.Tensor, k_splits: int
) -> torch.Tensor:
    return torch.relu(decompose_k_implementation(a, b, k_splits))


register_custom_op_autotuning(
    custom_op=custom_matmul_relu_dk,
    configs=[
        CustomOpConfig(k_splits=2),
        CustomOpConfig(k_splits=4),
        CustomOpConfig(k_splits=8),
        CustomOpConfig(k_splits=32),
        CustomOpConfig(k_splits=64),
    ],
    name="decompose_k_autotuned",
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda'),
        "b": lambda fake: torch.randn_like(fake, device='cuda'),
    },
)
```
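For reference, the decomposition body elided with `...` above might look like the sketch below. This is illustrative only: the helper name `decompose_k_sketch`, the shape handling, and the assumption that `k` is evenly divisible by `k_splits` are choices of the sketch, not part of this PR's API; the decomposition actually exercised by the tests lives in `test/inductor/test_custom_op_autotune.py` in the diff below.

```python
# Illustrative sketch of a k-way matmul decomposition (hypothetical, not the exact
# implementation in this PR). Assumes a is [m, k], b is [k, n], and k % k_splits == 0.
import torch


def decompose_k_sketch(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    m, k = a.shape
    _, n = b.shape
    k_chunk = k // k_splits

    # Split the K dimension into k_splits chunks and batch the partial matmuls.
    a_split = a.reshape(m, k_splits, k_chunk).permute(1, 0, 2)  # [k_splits, m, k_chunk]
    b_split = b.reshape(k_splits, k_chunk, n)                   # [k_splits, k_chunk, n]
    partial = torch.bmm(a_split, b_split)                       # [k_splits, m, n]

    # Reduce over the split dimension to recover a @ b.
    return torch.sum(partial, dim=0)  # [m, n]
```

Splitting K this way turns one long reduction into `k_splits` smaller batched matmuls plus a sum, which gives the autotuner a meaningful knob to sweep; after inlining, the final reduction and any epilogue remain ordinary, fusable IR nodes.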
### How It Works

Inlining the best decomposition enables further Inductor optimizations: the winning decomposition can fuse with surrounding elementwise operations and benefit from other graph-level optimizations, which potentially improves performance and memory efficiency.

During the custom op autotuning phase, we still benchmark all CustomOpConfigs to find the fastest implementation. During inline fusion, Inductor then inlines the winning decomposition into the main graph, converting it into individual ComputedBuffer IR nodes (which are fusable). Finally, Inductor automatically fuses the inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling).

Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If an ExternKernelChoice wins, no inlining happens.

### Performance Results

Benchmarked on a matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes):

![Benchmark summary](https://github.com/user-attachments/assets/22131d4c-a8ce-4f55-bdcd-ac758ddad8cd)

| Metric | Result |
| -- | -- |
| Average Speedup vs ATen | 1.28x |
| Max Speedup vs ATen | 1.41x |

The performance comparison is detailed in the plots below. On most of the tested shapes, inline fusion performs better than both the ATen baseline and the current torch.compile.

![Per-shape performance comparison](https://github.com/user-attachments/assets/190a1233-412f-4f34-84cd-9b7cb582f504)

**Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled.

--------------

### Integration into mm.py decomposeK

DecomposeK in mm.py was integrated behind the config flag `enable_inline_subgraph_fusion=True` (deprecated to avoid breaking async compilation; already removed from this PR).

FP32:

![FP32 results](https://github.com/user-attachments/assets/ee421d22-c426-42f2-8dcd-4dcc547d6219)

FP16:

![FP16 results](https://github.com/user-attachments/assets/346d1ffc-15af-40b0-9378-cf9b297711c2)

The TCF column represents torch.compile fusion, which is close to custom_op decompose-k; the difference might be due to different candidate k values.

#### Usage

Note: this only applies when `benchmark_epilogue_fusion` is disabled, i.e., when not using `multi_template_buffer`.

```python
# Define the matmul+relu function
def matmul_relu(x, y):
    return torch.nn.functional.relu(torch.matmul(x, y))


# Compile with inline subgraph fusion enabled
@torch.compile
def compiled_matmul_relu(x, y):
    return matmul_relu(x, y)


# Reset dynamo to ensure clean compilation
torch._dynamo.reset()

with config.patch(
    {
        "max_autotune": True,
        # CRITICAL: These two flags enable inline subgraph fusion
        "benchmark_epilogue_fusion": False,  # Must be False for inline fusion!
        "enable_inline_subgraph_fusion": True,  # Enable inline fusion
    }
):
    # Compile and run
    result = compiled_matmul_relu(a, b)
    torch.cuda.synchronize()
```

Pull Request resolved: pytorch#165952
Approved by: https://github.com/PaulZhang12, https://github.com/eellison
1 parent 9eebda9 commit ed4aa44

File tree

5 files changed: +151 -151 lines changed


test/inductor/test_custom_op_autotune.py

Lines changed: 53 additions & 125 deletions
```diff
@@ -216,115 +216,6 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
                 test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}"
             )
 
-    @skipIfXpu
-    def test_mlp_custom_op_autotune(self):
-        """Test MLP autotuning with method parameter controlling different decomposition variants.
-
-        Validates parametric tuning where the same decomposition function uses different
-        algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights).
-        """
-        test_op_name = f"test_lib::mlp_{id(self)}"
-
-        def mlp_variants(
-            input_tensor: torch.Tensor,
-            gate_weight: torch.Tensor,
-            up_weight: torch.Tensor,
-            down_weight: torch.Tensor,
-            method: int = 0,
-        ) -> torch.Tensor:
-            """MLP implementation with different computational approaches controlled by method parameter."""
-
-            if method == 0:
-                gate_proj = torch.matmul(input_tensor, gate_weight)
-                up_proj = torch.matmul(input_tensor, up_weight)
-                gated = torch.relu(gate_proj) * up_proj
-                return torch.matmul(gated, down_weight)
-
-            elif method == 1:
-                batch_shape = input_tensor.shape[:-1]
-                hidden_dim = input_tensor.shape[-1]
-                output_dim = down_weight.shape[-1]
-
-                input_2d = input_tensor.view(-1, hidden_dim)
-
-                gate_proj = torch.mm(input_2d, gate_weight)
-                up_proj = torch.mm(input_2d, up_weight)
-
-                gated = torch.relu(gate_proj) * up_proj
-                output_2d = torch.mm(gated, down_weight)
-
-                return output_2d.view(*batch_shape, output_dim)
-
-        @torch.library.custom_op(test_op_name, mutates_args=())
-        def test_mlp_op(
-            input_tensor: torch.Tensor,
-            gate_weight: torch.Tensor,
-            up_weight: torch.Tensor,
-            down_weight: torch.Tensor,
-            method: int = 0,
-        ) -> torch.Tensor:
-            return mlp_variants(
-                input_tensor, gate_weight, up_weight, down_weight, method=method
-            )
-
-        @test_mlp_op.register_fake
-        def _(
-            input_tensor: torch.Tensor,
-            gate_weight: torch.Tensor,
-            up_weight: torch.Tensor,
-            down_weight: torch.Tensor,
-            method: int = 0,
-        ):
-            return torch.empty(
-                input_tensor.shape[:-1] + (down_weight.shape[-1],),
-                device=input_tensor.device,
-                dtype=input_tensor.dtype,
-            )
-
-        # Use explicit config with method parameter as tuning knob
-        register_custom_op_autotuning(
-            test_mlp_op,
-            configs=[
-                CustomOpConfig(method=0),
-                CustomOpConfig(method=1),
-            ],
-            name="test_mlp_autotuned",
-            input_gen_fns={
-                "input_tensor": lambda fake_tensor: torch.randn_like(
-                    fake_tensor, device=self.device
-                )
-                * 0.1,
-                "gate_weight": lambda fake_tensor: torch.randn_like(
-                    fake_tensor, device=self.device
-                )
-                * 0.05,
-                "up_weight": lambda fake_tensor: torch.randn_like(
-                    fake_tensor, device=self.device
-                )
-                * 0.05,
-                "down_weight": lambda fake_tensor: torch.randn_like(
-                    fake_tensor, device=self.device
-                )
-                * 0.05,
-            },
-        )
-
-        # Create test inputs
-        input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
-
-        # Test that all method variants produce numerically equivalent results
-        expected = mlp_variants(
-            input_tensor, gate_weight, up_weight, down_weight, method=0
-        )
-
-        # Test autotuning
-        self._run_autotune_test(
-            test_mlp_op,
-            (input_tensor, gate_weight, up_weight, down_weight),
-            expected,
-            "MLP",
-        )
-
     def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
         """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
         # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
@@ -335,12 +226,12 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
 
     @skipIfXpu
     def test_decompose_k_custom_op_autotune(self):
-        """Test decompose_k autotuning with parametric tuning for k_splits values.
+        """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale).
 
-        Validates numerical parameter sweep where k_splits controls how the K dimension
-        is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]).
+        Validates that the custom op encapsulates the entire fused operation with parametric
+        tuning for k_splits values controlling how the K dimension is decomposed.
         """
-        test_op_name = f"test_lib::decompose_k_{id(self)}"
+        test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}"
 
         def decompose_k_implementation(
             a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
@@ -363,19 +254,23 @@ def decompose_k_implementation(
             return torch.sum(result, dim=0)  # [m, n]
 
         @torch.library.custom_op(test_op_name, mutates_args=())
-        def test_decompose_k_op(
-            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
+        def matmul_relu_epilogue_op(
+            a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4
         ) -> torch.Tensor:
-            """Matrix multiply with k-way decomposition - custom op using the decomposition."""
-            return decompose_k_implementation(a, b, k_splits)
-
-        @test_decompose_k_op.register_fake
-        def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
+            """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion)."""
+            matmul_result = decompose_k_implementation(a, b, k_splits)
+            biased = matmul_result + bias
+            activated = torch.relu(biased)
+            scaled = activated * 2.0
+            return scaled
+
+        @matmul_relu_epilogue_op.register_fake
+        def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4):
             return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
 
-        # Register autotuning with different k_splits values using decomposition function
+        # Register autotuning with different k_splits values
         register_custom_op_autotuning(
-            test_decompose_k_op,
+            matmul_relu_epilogue_op,
            configs=[
                 CustomOpConfig(k_splits=2),
                 CustomOpConfig(k_splits=4),
@@ -385,7 +280,7 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
                 CustomOpConfig(k_splits=64),
                 CustomOpConfig(k_splits=128),
             ],
-            name="test_decompose_k_autotuned",
+            name="matmul_relu_epilogue_autotuned",
             input_gen_fns={
                 "a": lambda fake_tensor: torch.randn_like(
                     fake_tensor, device=self.device
@@ -395,12 +290,45 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
                     fake_tensor, device=self.device
                 )
                 * 0.1,
+                "bias": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.1,
             },
         )
 
+        # Create test inputs
         a, b = self._create_decompose_k_inputs()
-        expected = a @ b
-        self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK")
+        bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1
+
+        # Compile the model using the custom op
+        @torch.compile
+        def test_model(a, b, bias):
+            return matmul_relu_epilogue_op(a, b, bias)
+
+        torch._dynamo.reset()
+
+        with config.patch(
+            max_autotune=True,
+            benchmark_fusion=True,
+        ):
+            compiled_result = test_model(a, b, bias)
+
+        def reference_model(a, b, bias):
+            matmul_result = a @ b
+            biased = matmul_result + bias
+            activated = torch.relu(biased)
+            scaled = activated * 2.0
+            return scaled
+
+        expected = reference_model(a, b, bias)
+
+        torch.testing.assert_close(
+            compiled_result,
+            expected,
+            rtol=2e-1,
+            atol=5e-1,
+        )
 
     @skipIfXpu
     def test_multi_parameter_tuning(self):
```

torch/_inductor/codegen/subgraph.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -24,6 +24,22 @@
 log = logging.getLogger(__name__)
 
 
+def inline_subgraph_to_ir_nodes(
+    gm: torch.fx.GraphModule, inputs: list[Any], name: str
+) -> Any:
+    """Inline a subgraph by converting its FX operations to individual IR nodes.
+
+    This converts a subgraph to multiple ComputedBuffer nodes (fusable),
+    enabling epilogue fusion with subsequent operations.
+
+    Returns:
+        TensorBox containing the final operation result as individual IR nodes
+    """
+    from torch._inductor.lowering import process_subgraph_nodes
+
+    return process_subgraph_nodes(gm, inputs)
+
+
 class SubgraphChoiceCaller(ir.ChoiceCaller):
     """
     Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
@@ -261,7 +277,14 @@ def make_fx_graph(
             # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
             from torch.fx.experimental.proxy_tensor import make_fx
 
-            return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
+            from ..decomposition import select_decomp_table
+
+            decomposition_table = select_decomp_table()
+
+            return make_fx(
+                functools.partial(decomp, **decomp_kwargs),
+                decomposition_table=decomposition_table,
+            )(*args)
 
         # Generate descriptive name for this variant
         variant_name = self._generate_variant_name(decomp, decomp_kwargs)
```

torch/_inductor/kernel/custom_op.py

Lines changed: 31 additions & 12 deletions
```diff
@@ -6,6 +6,7 @@
 from typing import Any, Optional, Union
 
 import torch
+from torch._inductor import config
 from torch._inductor.codegen.subgraph import SubgraphTemplate
 from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
 from torch._inductor.lowering import lowerings, validate_ir
@@ -158,7 +159,6 @@ def _adapt_user_input_gen_fns(
 
     Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
     """
-    from torch._inductor import config
 
     name_to_index = {name: i for i, name in enumerate(arg_names)}
     index_based_fns = {}
@@ -238,6 +238,7 @@
 
     This function generates multiple implementation choices for a custom operation and
     uses Inductor's autotuning system to select the best performing variant at runtime.
+    After selecting the best choice, applies inline fusion if the winning choice has a graph.
 
     Args:
         name: Unique identifier for the autotuning operation
@@ -320,14 +321,34 @@
     )
     input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)
 
-    return autotune_select_algorithm(
+    # Run autotuning and get both result and winning choice
+    selected_result, winning_choice = autotune_select_algorithm(
         name=name,
         choices=choices,
         input_nodes=list(inputs),
         layout=choices[0].layout,
         input_gen_fns=input_gen_fns,
+        return_choice=True,
     )
 
+    # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl)
+    if winning_choice.gm is not None:
+        log.debug(
+            "Inlining winning choice: %s (name=%s)",
+            getattr(winning_choice, "name", type(winning_choice).__name__),
+            name,
+        )
+        from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes
+
+        return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name)
+
+    log.debug(
+        "Winning choice does not support inlining: %s (name=%s)",
+        getattr(winning_choice, "name", type(winning_choice).__name__),
+        name,
+    )
+    return selected_result
+
 
 def register_custom_op_autotuning(
     custom_op: torch._library.custom_ops.CustomOpDef,
@@ -360,7 +381,7 @@ def my_attention(query, key, value, head_dim=32):
                "query": lambda fake: torch.randn_like(fake, device='cuda'),
                "key": lambda fake: torch.randn_like(fake, device='cuda'),
                "value": lambda fake: torch.randn_like(fake, device='cuda'),
-            }
+            },
        )
    """
    from torch._library.custom_ops import CustomOpDef
@@ -378,12 +399,12 @@ def my_attention(query, key, value, head_dim=32):
        raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
 
    processed_configs = []
-    for config in configs:
-        if isinstance(config, CustomOpConfig):
-            processed_configs.append(config)
+    for cfg in configs:
+        if isinstance(cfg, CustomOpConfig):
+            processed_configs.append(cfg)
        else:
            raise TypeError(
-                f"Each config must be a CustomOpConfig object, got {type(config)}"
+                f"Each config must be a CustomOpConfig object, got {type(cfg)}"
            )
 
    if not processed_configs:
@@ -402,14 +423,12 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
        decompositions = []
        non_tensor_args = []
 
-        for config in processed_configs:
-            decomp = config.get_decomposition(default_impl=default_impl)
+        for cfg in processed_configs:
+            decomp = cfg.get_decomposition(default_impl=default_impl)
            decompositions.append(decomp)
 
            # Merge config params with runtime kwargs (runtime takes precedence)
-            merged_kwargs = _merge_config_and_runtime_kwargs(
-                config.params, runtime_kwargs
-            )
+            merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
            non_tensor_args.append(merged_kwargs)
 
        result = autotune_custom_op(
```