Skip to content

Commit fc340d0

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Allow comparing device w/o index with device w/ index (pytorch#159665)
In the case where we have expected device "cuda" and given device "cuda:0" I think we should succeed? Pull Request resolved: pytorch#159665 Approved by: https://github.com/yushangdi
1 parent 53e47af commit fc340d0

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

aten/src/ATen/native/ComparisonUtils.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
2424
}
2525
}
2626

27+
template<>
28+
void _assert_match<c10::Device, std::optional<c10::Device>>(
29+
const c10::Device& original,
30+
const std::optional<c10::Device>& compared,
31+
const std::string& name) {
32+
if (compared) {
33+
const c10::Device& expected = compared.value();
34+
if (original.type() != expected.type()) {
35+
std::stringstream msg;
36+
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
37+
throw std::runtime_error(msg.str());
38+
}
39+
40+
// If the expected device doesn't have an index (e.g., just "cuda"),
41+
// or if both devices have the same index, consider them equal
42+
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
43+
std::stringstream msg;
44+
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
45+
throw std::runtime_error(msg.str());
46+
}
47+
}
48+
}
49+
2750
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
2851
_assert_match(tensor.sym_sizes(), sizes, "sizes");
2952
_assert_match(tensor.sym_strides(), strides, "strides");

test/export/test_export.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
OutputSpec,
6060
TensorArgument,
6161
)
62+
from torch.export.passes import move_to_device_pass
6263
from torch.fx.experimental.proxy_tensor import make_fx
6364
from torch.fx.experimental.symbolic_shapes import ShapeEnv
6465
from torch.testing import FileCheck
@@ -15914,6 +15915,22 @@ def forward(self, x):
1591415915
len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
1591515916
)
1591615917

15918+
@requires_cuda
15919+
def test_assert_tensor_metadata_device_index(self):
15920+
class N(torch.nn.Module):
15921+
def __init__(self):
15922+
super().__init__()
15923+
15924+
def forward(self, x, y):
15925+
x = x.float()
15926+
y = y.float()
15927+
return x + y
15928+
15929+
inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
15930+
ep = export(N(), inp)
15931+
ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
15932+
ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
15933+
1591715934
def test_input_output_no_stacktrace(self):
1591815935
class M(torch.nn.Module):
1591915936
def forward(self, x):

0 commit comments

Comments
 (0)