Commit a6321d6

Revert "Dont exclude constant_pad_nd in prologue fusion" (pytorch#150699)
Revert "Dont exclude constant_pad_nd in prologue fusion (pytorch#150145)" This reverts commit 6569576.
1 parent: 1cc51c6
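
For orientation, the restored heuristic below concerns prologues of this shape: a constant pad feeding a matmul that inductor may lower to a Triton template. A minimal sketch of the pattern (assuming a CUDA device and max-autotune; this snippet is illustrative and not part of the commit):

import torch
import torch.nn.functional as F

@torch.compile(mode="max-autotune")
def padded_mm(x, y):
    # F.pad with the default constant mode dispatches to aten.constant_pad_nd;
    # after this revert, inductor declines to fuse the pad into the matmul
    # template's prologue because it can turn aligned loads into unaligned ones.
    return F.pad(x, (0, 11)) @ y

x = torch.rand(250, 245, device="cuda")
y = torch.rand(256, 128, device="cuda")
out = padded_mm(x, y)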

File tree: 4 files changed (+30, -57 lines)


test/inductor/test_max_autotune.py

Lines changed: 7 additions & 18 deletions
@@ -1646,32 +1646,21 @@ def foo(x, y, z):
     @skipIfXpu
     @config.patch(shape_padding=True)
     @config.patch(force_shape_pad=True)
-    def test_prologue_masked_load(self):
+    @parametrize("sizes", ((250, 245, 128), (250, 256, 128), (256, 128, 62)))
+    def test_prologue_masked_load(self, sizes):
+        M, K, N = sizes
+
         def foo(x, y):
-            return x @ y.T
+            return x @ y

         x = torch.rand([250, 245], device=GPU_TYPE)
-        y = torch.rand([245, 128], device=GPU_TYPE).T.contiguous()
+        y = torch.rand([245, 128], device=GPU_TYPE)

         # we should not attempt prologue fusion if it turns an aligned load
         # into an unaligned load
         out, code = run_and_get_code(torch.compile(foo), x, y)
         self.assertEqual(out, foo(x, y), atol=0.05, rtol=0.05)
-        self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=2)
-
-    def test_masked_numeric(self):
-        # correctly detect upcast inside the cat mask, dont fuse
-        def foo(a, b, y):
-            return torch.cat([a, (b * 4)]) @ y.T
-
-        a = torch.rand([220, 245], device=GPU_TYPE, dtype=torch.float16)
-        b = torch.rand([20, 245], device=GPU_TYPE, dtype=torch.float16)
-        y = torch.rand([245, 128], device=GPU_TYPE, dtype=torch.float16).T.contiguous()
-
-        out, code = run_and_get_code(torch.compile(foo), a, b, y)
-
-        self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4)
-        self.assertEqual(out, foo(a, b, y), atol=0.05, rtol=0.05)
+        self.check_code(code[0], num_kernels=3, num_allocs=3, num_deallocs=4)


 if __name__ == "__main__":
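
For context, run_and_get_code compiles and runs the function while also capturing the source inductor generates, which is what check_code inspects for kernel and allocation counts. A rough standalone sketch of that flow (assuming a CUDA device; counting "@triton.jit" is only an illustrative proxy for the kernel count):

import torch
from torch._inductor.utils import run_and_get_code

def foo(x, y):
    return x @ y

x = torch.rand(250, 245, device="cuda")
y = torch.rand(245, 128, device="cuda")

out, code = run_and_get_code(torch.compile(foo), x, y)
# code is a list of generated source strings; count emitted Triton kernels
print(code[0].count("@triton.jit"))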

torch/_inductor/analyze_preserves_zero_mask.py

Lines changed: 4 additions & 26 deletions
@@ -1,14 +1,13 @@
 import dataclasses
 import itertools
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING

 import sympy

 import torch
 from torch._inductor import config
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._inductor.index_propagation import SymPyOps, TypedExpr
-from torch._prims_common import type_to_dtype

 from .ops_handler import DefaultHandler
 from .virtualized import StoreMode, V
@@ -110,32 +109,11 @@ def check_bounds(
     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
         return sympy.S.Zero

-    def masked(
-        self,
-        mask: DTypeContainer,
-        body: Callable[[], DTypeContainer],
-        other: DTypeContainer,
-    ) -> DTypeContainer:
-        return self.where(mask, other, body())
-
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        def to_constant(c: Union[int, float]) -> DTypeContainer:
-            return DTypeContainer(type_to_dtype(type(c)), is_scalar=True)
-
-        args = tuple(
-            a if not isinstance(a, (int, float)) else to_constant(a) for a in args
-        )
-        kwargs = {
-            k: v if not isinstance(v, (int, float)) else to_constant(v)
-            for k, v in kwargs.items()
-        }
-
         out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        is_scalar = all(
-            not isinstance(v, DTypeContainer) or v.is_scalar
-            for v in itertools.chain(args, kwargs.values())
-        )
-        out = DTypeContainer(out_dtype, is_scalar=is_scalar)
+        out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
+        if name == "constant":
+            return DTypeContainer(torch.float, is_scalar=True)

         uses_low_prec = any(
             isinstance(dtype_cont, DTypeContainer)
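
Loosely, this analysis asks whether the ops fused into a prologue map zero to zero, so that lanes a masked load fills with zero stay zero afterwards. A hedged illustration of that property in plain PyTorch (showing the idea being analyzed, not the handler's API):

import torch

x = torch.tensor([0.0, 1.0, 2.0])
zero_lanes = x == 0

# multiplication and relu preserve the zero mask: zero lanes stay zero
assert torch.equal((x * 4) == 0, zero_lanes)
assert torch.equal(torch.relu(x) == 0, zero_lanes)

# adding a nonzero constant does not: the zero lane becomes 1.0
assert not torch.equal((x + 1.0) == 0, zero_lanes)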

torch/_inductor/ir.py

Lines changed: 1 addition & 13 deletions
@@ -4365,19 +4365,7 @@ def dummy(index, rindex):  # type: ignore[no-untyped-def]
             )

         for inp in self.inputs:
-            layout = inp.layout
-
-            # we dont know what the iteration order is of the template,
-            # so we just want to make a single, contiguous dependency
-            if not layout.is_contiguous():
-                layout = FixedLayout(
-                    device=layout.device,
-                    dtype=layout.dtype,
-                    size=layout.size,
-                    stride=FlexibleLayout.contiguous_strides(layout.size),
-                    offset=layout.offset,
-                )
-            indexer = layout.make_indexer()
+            indexer = inp.layout.make_indexer()

             def dummy(index, rindex):  # type: ignore[no-untyped-def]
                 assert len(rindex) == 0
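
For reference, the removed branch above rebuilt the dependency with contiguous (row-major) strides, which for a given size are running products taken from the right. A small sketch of that computation (an illustration, not inductor's FlexibleLayout.contiguous_strides itself):

def contiguous_strides(size):
    # row-major: the last dim has stride 1; each earlier dim strides over
    # the product of all later dims
    strides = []
    running = 1
    for dim in reversed(size):
        strides.append(running)
        running *= dim
    return list(reversed(strides))

assert contiguous_strides([250, 245]) == [245, 1]
assert contiguous_strides([4, 3, 2]) == [6, 2, 1]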

torch/_inductor/scheduler.py

Lines changed: 18 additions & 0 deletions
@@ -3460,6 +3460,24 @@ def check_prologue_fusion_heuristics_fusable(
             why("prologue fusion will not increase amount of bytes read in kernel")
             return False

+        # we want to avoid attempting to fuse predictably unprofitable prologues
+        # such as increasing the unaligned reads or writes.
+        # TODO - would be nice to generalize this, however, we would need more explicit
+        # knowledge of memory access patterns in the TritonTemplate in order to know
+        # the stride order to check alignment.
+        origins = tuple(
+            e.target
+            for n in prologue_node.get_nodes()
+            if n.node is not None
+            for e in n.node.get_origins()
+            if e.op == "call_function"
+        )
+        if origins == (torch.ops.aten.constant_pad_nd.default,):
+            why(
+                "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads"
+            )
+            return False
+
         def low_prec_fp(dtype: torch.dtype) -> bool:
             return dtype.itemsize <= 2 and dtype.is_floating_point

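
As a point of reference, torch.nn.functional.pad in its default constant mode dispatches to aten.constant_pad_nd, which is the single call_function origin the restored check above matches. One way to see that mapping (a sketch, assuming make_fx traces down to aten ops as in recent PyTorch releases):

import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(lambda t: F.pad(t, (0, 11)))(torch.randn(250, 245))
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(targets)  # expected to include torch.ops.aten.constant_pad_nd.default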