Commit 6569576
Dont exclude constant_pad_nd in prologue fusion (pytorch#150145)
Dont exclude constant_pad_nd in prologue fusion (pytorch#149947)
Originally, I excluded constant_pad_nd from fusing to be conservative on compilation time. But, on benchmarking, you do occasionally get speedups by fusing it. Also includes a fix for making single, contiguous dep for prologues.
For instance, the following benchmark gets a 7% speedup by fusing in the constant_pad_nd.
```
import torch
import torch.nn.functional as F
torch._inductor.config.force_disable_caches = True
padded_N = 2048
n_pad_rows = 100
K, N = 2048, 4096
tensor1 = torch.randn(padded_N - n_pad_rows, 4096, device="cuda").to(torch.bfloat16)
tensor2 = torch.randn(4096, 4096, device="cuda").to(torch.bfloat16)
@torch.compile(mode='max-autotune-no-cudagraphs')
def masked_linear(input, weight, n_pad_input_rows):
"""
Linear layer with input padded by `n_pad_input_rows` rows
"""
# Use constant_pad_nd to pad with zeros for the invalid rows
padded_input = F.pad(tensor1, (0, 0, 0, n_pad_input_rows), "constant", 0)
return F.linear(padded_input, weight)
# Invoke the function
masked_linear(tensor1, tensor2, n_pad_rows)
```
Pull Request resolved: pytorch#149947
Approved by: https://github.com/drisspg
(cherry picked from commit 4c57aec)
Co-authored-by: eellison <[email protected]>1 parent 5416dff commit 6569576
File tree
4 files changed
+57
-30
lines changed- test/inductor
- torch/_inductor
4 files changed
+57
-30
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1646 | 1646 | | |
1647 | 1647 | | |
1648 | 1648 | | |
1649 | | - | |
1650 | | - | |
1651 | | - | |
1652 | | - | |
| 1649 | + | |
1653 | 1650 | | |
1654 | | - | |
| 1651 | + | |
1655 | 1652 | | |
1656 | 1653 | | |
1657 | | - | |
| 1654 | + | |
1658 | 1655 | | |
1659 | 1656 | | |
1660 | 1657 | | |
1661 | 1658 | | |
1662 | 1659 | | |
1663 | | - | |
| 1660 | + | |
| 1661 | + | |
| 1662 | + | |
| 1663 | + | |
| 1664 | + | |
| 1665 | + | |
| 1666 | + | |
| 1667 | + | |
| 1668 | + | |
| 1669 | + | |
| 1670 | + | |
| 1671 | + | |
| 1672 | + | |
| 1673 | + | |
| 1674 | + | |
1664 | 1675 | | |
1665 | 1676 | | |
1666 | 1677 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
| 3 | + | |
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
109 | 110 | | |
110 | 111 | | |
111 | 112 | | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
112 | 121 | | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
113 | 133 | | |
114 | | - | |
115 | | - | |
116 | | - | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
117 | 139 | | |
118 | 140 | | |
119 | 141 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4365 | 4365 | | |
4366 | 4366 | | |
4367 | 4367 | | |
4368 | | - | |
| 4368 | + | |
| 4369 | + | |
| 4370 | + | |
| 4371 | + | |
| 4372 | + | |
| 4373 | + | |
| 4374 | + | |
| 4375 | + | |
| 4376 | + | |
| 4377 | + | |
| 4378 | + | |
| 4379 | + | |
| 4380 | + | |
4369 | 4381 | | |
4370 | 4382 | | |
4371 | 4383 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3460 | 3460 | | |
3461 | 3461 | | |
3462 | 3462 | | |
3463 | | - | |
3464 | | - | |
3465 | | - | |
3466 | | - | |
3467 | | - | |
3468 | | - | |
3469 | | - | |
3470 | | - | |
3471 | | - | |
3472 | | - | |
3473 | | - | |
3474 | | - | |
3475 | | - | |
3476 | | - | |
3477 | | - | |
3478 | | - | |
3479 | | - | |
3480 | | - | |
3481 | 3463 | | |
3482 | 3464 | | |
3483 | 3465 | | |
| |||
0 commit comments