
Commit 6569576

pytorchbot and eellison authored
Don't exclude constant_pad_nd in prologue fusion (pytorch#150145)
Don't exclude constant_pad_nd in prologue fusion (pytorch#149947)

Originally, I excluded constant_pad_nd from fusing to be conservative on compilation time. But, on benchmarking, you do occasionally get speedups by fusing it. Also includes a fix for making a single, contiguous dep for prologues.

For instance, the following benchmark gets a 7% speedup by fusing in the constant_pad_nd:

```
import torch
import torch.nn.functional as F

torch._inductor.config.force_disable_caches = True

padded_N = 2048
n_pad_rows = 100
K, N = 2048, 4096

tensor1 = torch.randn(padded_N - n_pad_rows, 4096, device="cuda").to(torch.bfloat16)
tensor2 = torch.randn(4096, 4096, device="cuda").to(torch.bfloat16)

@torch.compile(mode='max-autotune-no-cudagraphs')
def masked_linear(input, weight, n_pad_input_rows):
    """
    Linear layer with input padded by `n_pad_input_rows` rows
    """
    # Use constant_pad_nd to pad with zeros for the invalid rows
    padded_input = F.pad(input, (0, 0, 0, n_pad_input_rows), "constant", 0)
    return F.linear(padded_input, weight)

# Invoke the function
masked_linear(tensor1, tensor2, n_pad_rows)
```

Pull Request resolved: pytorch#149947
Approved by: https://github.com/drisspg

(cherry picked from commit 4c57aec)

Co-authored-by: eellison <[email protected]>
1 parent 5416dff commit 6569576
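
As a usage note, here is a minimal sketch of how the quoted speedup could be timed locally. It assumes the `masked_linear`, `tensor1`, `tensor2`, and `n_pad_rows` definitions from the benchmark above; the `torch.utils.benchmark` harness and the iteration count are illustrative choices, not part of the commit.

```python
# Hypothetical timing harness (not part of the commit): measures the compiled
# masked_linear defined in the benchmark above using torch.utils.benchmark.
from torch.utils import benchmark

# Assumes masked_linear, tensor1, tensor2, and n_pad_rows already exist in scope,
# exactly as defined in the commit message's benchmark.
timer = benchmark.Timer(
    stmt="masked_linear(tensor1, tensor2, n_pad_rows)",
    globals={
        "masked_linear": masked_linear,
        "tensor1": tensor1,
        "tensor2": tensor2,
        "n_pad_rows": n_pad_rows,
    },
)
# Runs the statement 100 times and prints a Measurement summarizing time per call;
# compare against a run with prologue fusion disabled to reproduce the delta.
print(timer.timeit(100))
```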

File tree

4 files changed: +57 -30 lines


test/inductor/test_max_autotune.py

Lines changed: 18 additions & 7 deletions

```diff
@@ -1646,21 +1646,32 @@ def foo(x, y, z):
     @skipIfXpu
     @config.patch(shape_padding=True)
     @config.patch(force_shape_pad=True)
-    @parametrize("sizes", ((250, 245, 128), (250, 256, 128), (256, 128, 62)))
-    def test_prologue_masked_load(self, sizes):
-        M, K, N = sizes
-
+    def test_prologue_masked_load(self):
         def foo(x, y):
-            return x @ y
+            return x @ y.T

         x = torch.rand([250, 245], device=GPU_TYPE)
-        y = torch.rand([245, 128], device=GPU_TYPE)
+        y = torch.rand([245, 128], device=GPU_TYPE).T.contiguous()

         # we should not attempt prologue fusion if it turns an aligned load
         # into an unaligned load
         out, code = run_and_get_code(torch.compile(foo), x, y)
         self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
-        self.check_code(code[0], num_kernels=3, num_allocs=3, num_deallocs=4)
+        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
+
+    def test_masked_numeric(self):
+        # correctly detect upcast inside the cat mask, dont fuse
+        def foo(a, b, y):
+            return torch.cat([a, (b * 4)]) @ y.T
+
+        a = torch.rand([220, 245], device=GPU_TYPE, dtype=torch.float16)
+        b = torch.rand([20, 245], device=GPU_TYPE, dtype=torch.float16)
+        y = torch.rand([245, 128], device=GPU_TYPE, dtype=torch.float16).T.contiguous()
+
+        out, code = run_and_get_code(torch.compile(foo), a, b, y)
+
+        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4)
+        self.assertEqual(out, foo(a, b, y), atol=0.05, rtol=0.05)


 if __name__ == "__main__":
```
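
Outside the test harness, a rough way to see the same fusion is to inspect the wrapper source that `run_and_get_code` returns. The snippet below is a sketch under two assumptions not guaranteed by the commit: that setting the test's padding configs is what forces the `constant_pad_nd` prologue, and that every emitted kernel shows up as an `@triton.jit` definition in the generated source (the test's `check_code` helper does the real bookkeeping).

```python
# Rough, assumption-laden fusion check (not the test's check_code helper).
import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code

# Mirror the test's config so the matmul inputs get shape-padded, producing a
# constant_pad_nd prologue that the template can now fuse.
inductor_config.shape_padding = True
inductor_config.force_shape_pad = True

@torch.compile(mode="max-autotune-no-cudagraphs")
def foo(x, y):
    return x @ y.T

x = torch.rand([250, 245], device="cuda")
y = torch.rand([245, 128], device="cuda").T.contiguous()

out, code = run_and_get_code(foo, x, y)
# Assumes each emitted kernel appears as an "@triton.jit" definition in the
# generated source; with the pad prologue fused, this should ideally print 1.
print(code[0].count("@triton.jit"))
```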

torch/_inductor/analyze_preserves_zero_mask.py

Lines changed: 26 additions & 4 deletions

```diff
@@ -1,13 +1,14 @@
 import dataclasses
 import itertools
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import sympy

 import torch
 from torch._inductor import config
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._inductor.index_propagation import SymPyOps, TypedExpr
+from torch._prims_common import type_to_dtype

 from .ops_handler import DefaultHandler
 from .virtualized import StoreMode, V
@@ -109,11 +110,32 @@ def check_bounds(
     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
         return sympy.S.Zero

+    def masked(
+        self,
+        mask: DTypeContainer,
+        body: Callable[[], DTypeContainer],
+        other: DTypeContainer,
+    ) -> DTypeContainer:
+        return self.where(mask, other, body())
+
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        def to_constant(c: Union[int, float]) -> DTypeContainer:
+            return DTypeContainer(type_to_dtype(type(c)), is_scalar=True)
+
+        args = tuple(
+            a if not isinstance(a, (int, float)) else to_constant(a) for a in args
+        )
+        kwargs = {
+            k: v if not isinstance(v, (int, float)) else to_constant(v)
+            for k, v in kwargs.items()
+        }
+
         out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
-        if name == "constant":
-            return DTypeContainer(torch.float, is_scalar=True)
+        is_scalar = all(
+            not isinstance(v, DTypeContainer) or v.is_scalar
+            for v in itertools.chain(args, kwargs.values())
+        )
+        out = DTypeContainer(out_dtype, is_scalar=is_scalar)

         uses_low_prec = any(
             isinstance(dtype_cont, DTypeContainer)
```
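
For intuition, the scalar handling added to `_default` boils down to two rules: Python ints and floats get wrapped as scalar containers, and an op's output is treated as scalar only if every container input is itself scalar. The toy below restates that logic with a stand-in dataclass; `ToyDTypeContainer`, `wrap_constant`, and `propagate` are illustrative names, not Inductor's API.

```python
# Toy model (not Inductor's code) of the new is_scalar propagation rule.
import dataclasses
import torch

@dataclasses.dataclass
class ToyDTypeContainer:
    dtype: torch.dtype
    is_scalar: bool = False

def wrap_constant(c):
    # Python ints/floats become scalar containers, mirroring to_constant in the patch.
    dtype = torch.int64 if isinstance(c, int) else torch.get_default_dtype()
    return ToyDTypeContainer(dtype, is_scalar=True)

def propagate(out_dtype, *inputs):
    wrapped = [i if isinstance(i, ToyDTypeContainer) else wrap_constant(i) for i in inputs]
    # The result is a scalar only if every input is a scalar.
    return ToyDTypeContainer(out_dtype, is_scalar=all(v.is_scalar for v in wrapped))

tensor_like = ToyDTypeContainer(torch.float16, is_scalar=False)
print(propagate(torch.float16, tensor_like, 4).is_scalar)  # False: one input is tensor-like
print(propagate(torch.float32, 2, 3.0).is_scalar)          # True: all inputs are scalars
```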

torch/_inductor/ir.py

Lines changed: 13 additions & 1 deletion

```diff
@@ -4365,7 +4365,19 @@ def dummy(index, rindex):  # type: ignore[no-untyped-def]
         )

         for inp in self.inputs:
-            indexer = inp.layout.make_indexer()
+            layout = inp.layout
+
+            # we dont know what the iteration order is of the template,
+            # so we just want to make a single, contiguous dependency
+            if not layout.is_contiguous():
+                layout = FixedLayout(
+                    device=layout.device,
+                    dtype=layout.dtype,
+                    size=layout.size,
+                    stride=FlexibleLayout.contiguous_strides(layout.size),
+                    offset=layout.offset,
+                )
+            indexer = layout.make_indexer()

             def dummy(index, rindex):  # type: ignore[no-untyped-def]
                 assert len(rindex) == 0
```
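
For reference, `FlexibleLayout.contiguous_strides` yields standard row-major strides, which is what makes the replacement dependency a single, contiguous one. The helper below is a from-scratch restatement of that formula for illustration, not the Inductor implementation.

```python
# Row-major stride computation: stride[i] is the product of all sizes to the right of i.
# For size [2048, 4096] this yields [4096, 1], i.e. one dense, contiguous region.
def contiguous_strides(size):
    strides = []
    running = 1
    for s in reversed(size):
        strides.append(running)
        running *= s
    return list(reversed(strides))

print(contiguous_strides([2048, 4096]))  # [4096, 1]
print(contiguous_strides([4, 3, 2]))     # [6, 2, 1]
```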

torch/_inductor/scheduler.py

Lines changed: 0 additions & 18 deletions

```diff
@@ -3460,24 +3460,6 @@ def check_prologue_fusion_heuristics_fusable(
             why("prologue fusion will not increase amount of bytes read in kernel")
             return False

-        # we want to avoid attempting to fuse predictably unprofitable prologues
-        # such as increasing the unaligned reads or writes.
-        # TODO - would be nice to generalize this, however, we would need more explicit
-        # knowledge of memory access patterns in the TritonTemplate in order to know
-        # the stride order to check alignment.
-        origins = tuple(
-            e.target
-            for n in prologue_node.get_nodes()
-            if n.node is not None
-            for e in n.node.get_origins()
-            if e.op == "call_function"
-        )
-        if origins == (torch.ops.aten.constant_pad_nd.default,):
-            why(
-                "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads"
-            )
-            return False
-
         def low_prec_fp(dtype: torch.dtype) -> bool:
             return dtype.itemsize <= 2 and dtype.is_floating_point
```
