Skip to content

Commit 9eebda9

Browse files
laithsakkapytorchmergebot
authored andcommitted
make narrow_tensor_symint DDE-free (pytorch#166379)
pytorch#158081 Pull Request resolved: pytorch#166379 Approved by: https://github.com/Lucaskabela ghstack dependencies: pytorch#166361
1 parent 09d8953 commit 9eebda9

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,8 +1784,8 @@ Tensor narrow_tensor_symint(
17841784
start.dim() == 0 &&
17851785
isIntegralType(start.scalar_type(), /*includeBool=*/false),
17861786
"start must be an 0-dim integral Tensor.");
1787-
int64_t st = start.item<int64_t>();
1788-
return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
1787+
c10::SymInt st = start.item().toSymInt();
1788+
return at::narrow_symint(self, dim, std::move(st), std::move(length));
17891789
}
17901790

17911791
std::

test/functorch/test_aotdispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8126,7 +8126,7 @@ def fn(x):
81268126
xfail("corrcoef"),
81278127
xfail("quantile"),
81288128
xfail("nanquantile"),
8129-
xfail("narrow"),
8129+
skip("narrow"),
81308130
xfail("istft"),
81318131
xfail("linalg.eig"),
81328132
skip("as_strided_scatter"),

test/test_dynamic_shapes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4452,6 +4452,19 @@ def test_narrow_unbacked_start_cpp_wrapper(self):
44524452
"""Test narrow with unbacked start with cpp_wrapper"""
44534453
self.test_narrow_unbacked_start()
44544454

4455+
@torch._dynamo.config.patch(capture_scalar_outputs=True)
4456+
def test_narrow_with_tensor_start(self):
4457+
@torch.compile(backend="inductor", fullgraph=True)
4458+
def f(x, start, end):
4459+
return torch.narrow(x, 0, start, end)
4460+
4461+
x = torch.tensor(
4462+
[False], device="cuda:0" if torch.cuda.is_available() else "cpu"
4463+
)
4464+
start = torch.tensor(0)
4465+
res = f(x, start, 0)
4466+
self.assertEqual(res.shape, torch.Size([0]))
4467+
44554468

44564469
instantiate_parametrized_tests(TestUnbacked)
44574470

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,6 @@ def f(t):
19871987
}
19881988

19891989
only_fake_tensor_failures = {
1990-
xfail('narrow'),
19911990
xfail('tensor_split'),
19921991
}
19931992

0 commit comments

Comments
 (0)