Skip to content

Commit a93f9dc

Browse files
Improve tensor comparison for correctness checks
ghstack-source-id: 4119214 Pull-Request: #44
1 parent 85547e1 commit a93f9dc

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

facto/inputgen/argtuple/gen.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,27 @@
1616
from facto.inputgen.utils.config import Condition, TensorConfig
1717

1818

19+
def compare_tensors(x, y):
20+
x = x.to(torch.float)
21+
y = y.to(torch.float)
22+
c = torch.isclose(x, y, atol=1e-4, rtol=1e-2, equal_nan=True)
23+
if c.all():
24+
return True, 0, lambda t: t
25+
else:
26+
d = x[c == False] - y[c == False]
27+
x_d = x[c == False]
28+
y_d = y[c == False]
29+
xy = torch.logical_or(
30+
torch.logical_and(x_d > 1e30, y_d == float("inf")),
31+
torch.logical_and(x_d == float("inf"), y_d > 1e30),
32+
)
33+
if xy.all():
34+
return True, 0, lambda t: t
35+
else:
36+
ix = xy == False
37+
return False, d[ix].abs().max().item(), lambda t: t[c == False][ix]
38+
39+
1940
class ArgumentTupleGenerator:
2041
def __init__(self, spec: Spec, config: Optional[TensorConfig] = None):
2142
self.spec = spec
@@ -188,21 +209,23 @@ def gen_errors(
188209
if isinstance(ret, torch.Tensor) and isinstance(
189210
cpu_ret, torch.Tensor
190211
):
191-
if not torch.allclose(
192-
cpu_ret, ret.to("cpu"), equal_nan=True
193-
):
194-
cpu_ret_f = cpu_ret.to(torch.float)
195-
ret_f = ret.to("cpu").to(torch.float)
196-
197-
max_diff = (cpu_ret_f - ret_f).abs().max()
212+
allclose, max_diff, ix_fun = compare_tensors(
213+
cpu_ret, ret.to("cpu")
214+
)
215+
if not allclose:
198216
if verbose:
199217
print(f"Output mismatch: {max_diff}")
200218
print(
201219
op.__name__, str([str(x) for x in meta_tuple])
202220
)
203-
if ret.numel() < 10:
204-
print(ret)
205-
print(cpu_ret)
221+
# for cpu_arg in cpu_posargs:
222+
# if isinstance(cpu_arg, torch.Tensor) and cpu_arg.numel() < 10:
223+
# print(cpu_arg)
224+
# elif isinstance(cpu_arg, torch.Tensor) and cpu_arg.shape == cpu_ret.shape:
225+
# print(ix_fun(cpu_arg))
226+
# if ret.numel() < 10:
227+
# print(ret)
228+
# print(cpu_ret)
206229
yield posargs, inkwargs, outargs
207230
else:
208231
# When valid=False, we expect failure, so success IS a bug

0 commit comments

Comments
 (0)