Commit ed4aa44
CustomOp Inline Fusion (pytorch#165952)
Add Inline Fusion Support for Custom Op Autotuning
--------------------------------------------------
This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph for fusion with surrounding operations.
### Usage
```python
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition."""
...
@torch.library.custom_op("my_lib::matmul_relu", mutates_args={})
def custom_matmul_relu_dk(
a: torch.Tensor, b: torch.Tensor, k_splits: int
) -> torch.Tensor:
return torch.relu(decompose_k_implementation(a, b, k_splits))
register_custom_op_autotuning(
custom_op=custom_matmul_relu_dk,
configs=[
CustomOpConfig(k_splits=2),
CustomOpConfig(k_splits=4),
CustomOpConfig(k_splits=8),
CustomOpConfig(k_splits=32),
CustomOpConfig(k_splits=64),
],
name="decompose_k_autotuned",
input_gen_fns={
"a": lambda fake: torch.randn_like(fake, device='cuda'),
"b": lambda fake: torch.randn_like(fake, device='cuda'),
}
)
```
### How It Works
Enable optimizations from Inductor by inlining the best decomposition, allowing fusion with surrounding elementwise operations and other graph-level optimizations. This provide potentially better performance and memory efficiency.
During customop autotuning phase, we still benchmarks all CustomOpConfigs to find the fastest implementation. Then during inline fusion, inductor inline the decompositions into the main graph, converting the winning choice to individual ComputedBuffer IR nodes (fusable). At the end, Inductor automatically fuses inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling). Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If the ExternKernelChoice is returned, no inline happens.
Performance Results
Benchmarked on matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes):
<img width="782" height="377" alt="Screenshot 2025-11-04 at 12 43 11 AM" src="https://github.com/user-attachments/assets/22131d4c-a8ce-4f55-bdcd-ac758ddad8cd" />
Metric | Result
-- | --
Average Speedup vs ATen | 1.28x
Max Speedup vs ATen | 1.41x
<br class="Apple-interchange-newline">
The performance comparison are detailed in the below plots. We spot that on most use cases, the inline fusion gains better performance compared to aten baseline and the current torch.compile.
<img width="4874" height="3545" alt="image" src="https://github.com/user-attachments/assets/190a1233-412f-4f34-84cd-9b7cb582f504" />
**Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled.
--------------
### Integration to mm.py decomposeK with a flag enable_inline_subgraph_fusion=True in config (deprecated to avoid breaking async compilation. removed from the PR already)
FP32:
<img width="738" height="357" alt="Screenshot 2025-11-04 at 12 05 08 AM" src="https://github.com/user-attachments/assets/ee421d22-c426-42f2-8dcd-4dcc547d6219" />
FP16:
<img width="769" height="403" alt="Screenshot 2025-11-04 at 12 13 49 AM" src="https://github.com/user-attachments/assets/346d1ffc-15af-40b0-9378-cf9b297711c2" />
The TCF column represents torch compile fusion, which is close to custom_op decomposek. The difference might due to different candidate k values.
#### Usage:
Note: this only happens when we don't benchmark_epilogue_fusion, i.e., not using multi_template_buffer.
```python
# Define the matmul+relu function
def matmul_relu(x, y):
return torch.nn.functional.relu(torch.matmul(x, y))
# Compile with inline subgraph fusion enabled
@torch.compile
def compiled_matmul_relu(x, y):
return matmul_relu(x, y)
# Reset dynamo to ensure clean compilation
torch._dynamo.reset()
with config.patch(
{
"max_autotune": True,
# CRITICAL: These two flags enable inline subgraph fusion
"benchmark_epilogue_fusion": False, # Must be False for inline fusion!
"enable_inline_subgraph_fusion": True, # Enable inline fusion
}
):
# Compile and run
result = compiled_matmul_relu(a, b)
torch.cuda.synchronize()
```
Pull Request resolved: pytorch#165952
Approved by: https://github.com/PaulZhang12, https://github.com/eellison1 parent 9eebda9 commit ed4aa44
File tree
5 files changed
+151
-151
lines changed- test/inductor
- torch/_inductor
- codegen
- kernel
5 files changed
+151
-151
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
216 | 216 | | |
217 | 217 | | |
218 | 218 | | |
219 | | - | |
220 | | - | |
221 | | - | |
222 | | - | |
223 | | - | |
224 | | - | |
225 | | - | |
226 | | - | |
227 | | - | |
228 | | - | |
229 | | - | |
230 | | - | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
242 | | - | |
243 | | - | |
244 | | - | |
245 | | - | |
246 | | - | |
247 | | - | |
248 | | - | |
249 | | - | |
250 | | - | |
251 | | - | |
252 | | - | |
253 | | - | |
254 | | - | |
255 | | - | |
256 | | - | |
257 | | - | |
258 | | - | |
259 | | - | |
260 | | - | |
261 | | - | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | | - | |
268 | | - | |
269 | | - | |
270 | | - | |
271 | | - | |
272 | | - | |
273 | | - | |
274 | | - | |
275 | | - | |
276 | | - | |
277 | | - | |
278 | | - | |
279 | | - | |
280 | | - | |
281 | | - | |
282 | | - | |
283 | | - | |
284 | | - | |
285 | | - | |
286 | | - | |
287 | | - | |
288 | | - | |
289 | | - | |
290 | | - | |
291 | | - | |
292 | | - | |
293 | | - | |
294 | | - | |
295 | | - | |
296 | | - | |
297 | | - | |
298 | | - | |
299 | | - | |
300 | | - | |
301 | | - | |
302 | | - | |
303 | | - | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
311 | | - | |
312 | | - | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | | - | |
327 | | - | |
328 | 219 | | |
329 | 220 | | |
330 | 221 | | |
| |||
335 | 226 | | |
336 | 227 | | |
337 | 228 | | |
338 | | - | |
| 229 | + | |
339 | 230 | | |
340 | | - | |
341 | | - | |
| 231 | + | |
| 232 | + | |
342 | 233 | | |
343 | | - | |
| 234 | + | |
344 | 235 | | |
345 | 236 | | |
346 | 237 | | |
| |||
363 | 254 | | |
364 | 255 | | |
365 | 256 | | |
366 | | - | |
367 | | - | |
| 257 | + | |
| 258 | + | |
368 | 259 | | |
369 | | - | |
370 | | - | |
371 | | - | |
372 | | - | |
373 | | - | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
374 | 269 | | |
375 | 270 | | |
376 | | - | |
| 271 | + | |
377 | 272 | | |
378 | | - | |
| 273 | + | |
379 | 274 | | |
380 | 275 | | |
381 | 276 | | |
| |||
385 | 280 | | |
386 | 281 | | |
387 | 282 | | |
388 | | - | |
| 283 | + | |
389 | 284 | | |
390 | 285 | | |
391 | 286 | | |
| |||
395 | 290 | | |
396 | 291 | | |
397 | 292 | | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
398 | 297 | | |
399 | 298 | | |
400 | 299 | | |
| 300 | + | |
401 | 301 | | |
402 | | - | |
403 | | - | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
404 | 332 | | |
405 | 333 | | |
406 | 334 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
27 | 43 | | |
28 | 44 | | |
29 | 45 | | |
| |||
261 | 277 | | |
262 | 278 | | |
263 | 279 | | |
264 | | - | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
265 | 288 | | |
266 | 289 | | |
267 | 290 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| 9 | + | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
| |||
158 | 159 | | |
159 | 160 | | |
160 | 161 | | |
161 | | - | |
162 | 162 | | |
163 | 163 | | |
164 | 164 | | |
| |||
238 | 238 | | |
239 | 239 | | |
240 | 240 | | |
| 241 | + | |
241 | 242 | | |
242 | 243 | | |
243 | 244 | | |
| |||
320 | 321 | | |
321 | 322 | | |
322 | 323 | | |
323 | | - | |
| 324 | + | |
| 325 | + | |
324 | 326 | | |
325 | 327 | | |
326 | 328 | | |
327 | 329 | | |
328 | 330 | | |
| 331 | + | |
329 | 332 | | |
330 | 333 | | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
331 | 352 | | |
332 | 353 | | |
333 | 354 | | |
| |||
360 | 381 | | |
361 | 382 | | |
362 | 383 | | |
363 | | - | |
| 384 | + | |
364 | 385 | | |
365 | 386 | | |
366 | 387 | | |
| |||
378 | 399 | | |
379 | 400 | | |
380 | 401 | | |
381 | | - | |
382 | | - | |
383 | | - | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
384 | 405 | | |
385 | 406 | | |
386 | | - | |
| 407 | + | |
387 | 408 | | |
388 | 409 | | |
389 | 410 | | |
| |||
402 | 423 | | |
403 | 424 | | |
404 | 425 | | |
405 | | - | |
406 | | - | |
| 426 | + | |
| 427 | + | |
407 | 428 | | |
408 | 429 | | |
409 | 430 | | |
410 | | - | |
411 | | - | |
412 | | - | |
| 431 | + | |
413 | 432 | | |
414 | 433 | | |
415 | 434 | | |
| |||
0 commit comments