Commit 53ab730

ColinPeppler authored and pytorchmergebot committed
[inductor] support unbacked symint in sdpfa (pytorch#157739)
Pull Request resolved: pytorch#157739
Approved by: https://github.com/laithsakka
1 parent 08e9dd2 commit 53ab730

File tree

4 files changed: +55 -9 lines changed


test/inductor/test_unbacked_symints.py

Lines changed: 25 additions & 0 deletions
@@ -489,6 +489,31 @@ def fn(q, k, vector, scalar):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_sdpfa(self, device):
+        if device == "cpu":
+            raise unittest.SkipTest(
+                "scaled_dot_product_flash_attention has no CPU backend"
+            )
+
+        def fn(x):
+            B, H, d_h = 2, 4, 8
+            nz = torch.nonzero(x)
+            seq_len = nz.size(0)
+
+            q = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            k = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            v = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+
+            result = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                q, k, v, dropout_p=0.0, is_causal=False, scale=None
+            )
+            return result
+
+        x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
+        torch.compile(fn, fullgraph=True)(x)
+
 
 instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
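In user-facing terms, the new test exercises a graph where torch.nonzero makes the attention sequence length data-dependent, so it reaches the flash-attention lowering as an unbacked symint. A minimal sketch of that pattern using the public SDPA wrapper rather than the private aten op (my own example, not part of the diff; assumes a CUDA device and enables dynamic output-shape capture globally instead of via the patch decorator):

import torch
import torch._dynamo.config as dynamo_config

# nonzero's output size is data-dependent, so seq_len becomes an unbacked symint
# under torch.compile.
dynamo_config.capture_dynamic_output_shape_ops = True

def attention_over_nonzeros(x):
    seq_len = torch.nonzero(x).size(0)
    q = torch.randn(2, 4, seq_len, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 4, seq_len, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 4, seq_len, 8, device="cuda", dtype=torch.float16)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

compiled = torch.compile(attention_over_nonzeros, fullgraph=True)
out = compiled(torch.tensor([1.0, 0.0, 1.0, 0.0], device="cuda"))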

torch/_inductor/ir.py

Lines changed: 17 additions & 5 deletions
@@ -437,8 +437,16 @@ def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool:
     return get_device_type(x) == "cpu"
 
 
-def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> bool:
-    if not isinstance(x, IRNode) or x.maybe_get_stride() is None:
+def is_aligned_realized_tensor_hint(
+    x: Union[Buffer, TensorBox], alignment: int
+) -> bool:
+    # Use this as a hint. This won't guard since size_hint doesn't guard.
+    if (
+        not isinstance(x, IRNode)
+        or x.maybe_get_stride() is None
+        or free_unbacked_symbols(x.get_stride())
+        or free_unbacked_symbols(x.get_size())
+    ):
         return False
 
     aligned_strides = all(
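The added free_unbacked_symbols checks make the alignment test conservative: when a size or stride expression still contains an unbacked symbol there is no integer hint to decide alignment from, so the function simply reports the tensor as unaligned. A standalone sketch of that idea (sympy only; the 16-element alignment and the simplified stride check are illustrative, not the Inductor implementation):

import sympy

ALIGNMENT = 16  # illustrative element alignment

def is_aligned_hint(strides):
    # Bail out if any stride carries an unbacked symbol (conventionally named
    # "u0", "u1", ...): no hint exists for it, so alignment cannot be decided.
    for s in strides:
        if isinstance(s, sympy.Expr) and any(str(f).startswith("u") for f in s.free_symbols):
            return False
    # Otherwise require outer strides to be multiples of ALIGNMENT and the
    # innermost stride to be 1 (a simplification of the real check).
    return all(int(s) % ALIGNMENT == 0 for s in strides[:-1]) and int(strides[-1]) == 1

u0 = sympy.Symbol("u0", integer=True, positive=True)
print(is_aligned_hint([32 * u0, 8 * u0, 8, 1]))  # False: unbacked seq_len in the strides
print(is_aligned_hint([256, 64, 16, 1]))         # True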
@@ -5674,17 +5682,21 @@ def require_strides(
             # the current size and stride already satisfies this order.
             # However by freezing it to the required order, the layout will be changed to:
             # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.
-
+            use_current_stride_order = is_stride_order_storage_and_layout(
+                x, order
+            ) and not free_unbacked_symbols(x.get_layout().stride)
             # fix flexiblelayout to be FixedLayout with stride_order
             as_storage_and_layout(
                 x,
                 freeze=True,
                 want_contiguous=False,
                 stride_order=(
                     get_stride_order(
-                        V.graph.sizevars.size_hints(x.get_layout().stride)
+                        V.graph.sizevars.size_hints_or_throw(
+                            x.get_layout().stride
+                        )
                     )
-                    if is_stride_order_storage_and_layout(x, order)
+                    if use_current_stride_order
                     else order
                 ),
                 allow_padding=allow_padding,

torch/_inductor/lowering.py

Lines changed: 5 additions & 4 deletions
@@ -2538,15 +2538,16 @@ def apply_constraint(idx, arg, fx_arg):
         if len(arg.get_size()) not in (3, 4):
             return arg
 
-        if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
+        is_aligned_tensor = ir.is_aligned_realized_tensor_hint(arg, ALIGNMENT)
+        if is_aligned_tensor:
             return ir.try_match_insignificant_strides(
                 ir.ExternKernel.realize_input(arg), meta_stride_expr
             )
 
         if (
             isinstance(arg, IRNode)
             and arg.maybe_get_stride() is not None
-            and ir.is_aligned_realized_tensor(arg, ALIGNMENT)
+            and is_aligned_tensor
         ):
             return ir.try_match_insignificant_strides(
                 ir.ExternKernel.realize_input(arg), meta_stride_expr
@@ -2590,15 +2591,15 @@ def apply_constraint(idx, arg, fx_arg):
 
             return ir.ExternKernel.require_exact_strides(arg, out_strides)
 
-        if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
+        if is_aligned_tensor:
             return ir.try_match_insignificant_strides(
                 ir.ExternKernel.realize_input(arg), meta_stride_expr
             )
 
         if (
             isinstance(arg, IRNode)
             and arg.maybe_get_stride() is not None
-            and ir.is_aligned_realized_tensor(arg, ALIGNMENT)
+            and is_aligned_tensor
         ):
             return ir.try_match_insignificant_strides(
                 ir.ExternKernel.realize_input(arg), meta_stride_expr
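Taken together with the ir.py change, apply_constraint now answers the alignment question once per argument, and the undecidable case (unbacked strides) counts as unaligned, so such tensors fall through to the stride-enforcing branches instead of keeping their current layout. A rough restatement of that flow with placeholder names (the real code calls ir.try_match_insignificant_strides and ir.ExternKernel.require_exact_strides on IR nodes):

def apply_sdpa_constraint(arg, is_aligned_hint):
    # is_aligned_hint is False both for genuinely misaligned tensors and for
    # tensors whose strides contain unbacked symints (no hint available).
    if is_aligned_hint:
        return f"keep current layout of {arg}"        # match insignificant strides
    return f"enforce required strides for {arg}"      # stride-enforcing fallback

print(apply_sdpa_constraint("q", is_aligned_hint=False))  # the unbacked seq_len case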

torch/_inductor/sizevars.py

Lines changed: 8 additions & 0 deletions
@@ -559,6 +559,7 @@ def size_hint(
             raise
 
     def size_hint_or_throw(self, expr: Union[Expr, int]) -> int:
+        # Like size_hint but there's no fallback for unbacked symints, so it throws.
         out = self.symbolic_hint(expr)
         try:
             return int(out)
@@ -574,6 +575,13 @@ def size_hints(
     ) -> tuple[int, ...]:
         return tuple(self.size_hint(x, fallback=fallback) for x in exprs)
 
+    def size_hints_or_throw(
+        self,
+        exprs: Iterable[Union[Expr, int]],
+    ) -> tuple[int, ...]:
+        # Like size_hints but there's no fallback for unbacked symints, so it throws.
+        return tuple(self.size_hint_or_throw(x) for x in exprs)
+
     def _lru_cache(self, fn, maxsize=None):
         """
         Wrapper around functools.lru_cache that clears when replacements
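The *_or_throw variants make the no-fallback behavior explicit: size_hint can paper over an unbacked symint with a caller-provided fallback, while size_hint_or_throw always converts the hinted expression and raises when an unbacked symbol remains, which is the behavior the stride-order computation in ir.py relies on. A standalone sketch of the two behaviors (sympy only; the hint map and fallback plumbing are illustrative, not the SizeVarAllocator implementation):

import sympy

def size_hint(expr, var_to_hint, fallback=None):
    hinted = expr.subs(var_to_hint) if isinstance(expr, sympy.Expr) else sympy.Integer(expr)
    if hinted.free_symbols and fallback is not None:
        return fallback              # an unbacked symbol remained: use the caller's fallback
    return int(hinted)               # raises if an unbacked symbol remains

def size_hint_or_throw(expr, var_to_hint):
    # No fallback path at all: always convert, raising on unbacked symbols.
    hinted = expr.subs(var_to_hint) if isinstance(expr, sympy.Expr) else sympy.Integer(expr)
    return int(hinted)

s0, u0 = sympy.symbols("s0 u0", integer=True, positive=True)
hints = {s0: 128}                    # backed symbol with a known example value
print(size_hint(8 * s0, hints))                  # 1024
print(size_hint(8 * u0, hints, fallback=8192))   # 8192: the fallback papers over u0
print(size_hint_or_throw(8 * s0, hints))         # 1024
# size_hint_or_throw(8 * u0, hints)              # raises: no hint exists for u0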
