Skip to content

Commit 9bd0830

Browse files
pianpwkpytorchmergebot
authored andcommitted
[dynamic shapes] guard_or_false for cat, repeat (pytorch#155290)
Summary: assumes: - specified repeats are non-negative - 1d cat arguments like [u0] aren't non-zero sized (replaces existing size-oblivious) Test Plan: test_export Rollback Plan: Differential Revision: D76092011 Pull Request resolved: pytorch#155290 Approved by: https://github.com/laithsakka
1 parent 4609699 commit 9bd0830

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,16 @@ def fn(x):
579579
self.assertEqual(obj.y, x + 1)
580580
self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})
581581

582+
@torch._dynamo.config.patch(capture_scalar_outputs=True)
583+
def test_unbacked_repeat_cat(self):
584+
def f(x, n):
585+
m = x.item()
586+
x = torch.empty(x).repeat(n) # s0*u0
587+
return torch.cat([x, x], dim=0)
588+
589+
fn = torch.compile(f, backend="eager", dynamic=True, fullgraph=True)
590+
fn(torch.tensor([5]), 5)
591+
582592
def test_tensor_setattr_getset_descriptor(self):
583593
# Tensor attribute `real` has special getter/setter for complex dtype.
584594
def f(x):

torch/_meta_registrations.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4078,6 +4078,11 @@ def meta_repeat(self, repeats):
40784078
len(repeats) >= self.dim(),
40794079
lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
40804080
)
4081+
for i, rep in enumerate(repeats):
4082+
torch._check(
4083+
rep >= 0,
4084+
lambda: f"Repeats cannot be negative, found {rep} at index {i}",
4085+
)
40814086
# Add new leading dimensions to the tensor if the
40824087
# number of target dimensions is larger than the
40834088
# number of source dimensions.

torch/_refs/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,10 @@ def cat_compute_output_memory_format(inputs):
27692769

27702770
utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
27712771

2772-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
2772+
from torch.fx.experimental.symbolic_shapes import (
2773+
guard_or_false,
2774+
guard_size_oblivious,
2775+
)
27732776

27742777
# This is a bit tricky. Naively, you would expect to just pick one
27752778
# arbitrary tensor and check that all tensors match this tensor. However,
@@ -2830,7 +2833,7 @@ def cat_compute_output_memory_format(inputs):
28302833
)
28312834
else:
28322835
# Remove inputs that are 1-D, zero size
2833-
if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0):
2836+
if tensor.ndim == 1 and guard_or_false(tensor.shape[0] == 0):
28342837
continue
28352838
# Don't bother checking size match, prims.cat will handle it
28362839
filtered.append(tensor)

0 commit comments

Comments
 (0)