diff --git a/facto/inputgen/argtuple/gen.py b/facto/inputgen/argtuple/gen.py index cb6e6f4..390b3ca 100644 --- a/facto/inputgen/argtuple/gen.py +++ b/facto/inputgen/argtuple/gen.py @@ -16,6 +16,27 @@ from facto.inputgen.utils.config import Condition, TensorConfig +def compare_tensors(x, y): + x = x.to(torch.float) + y = y.to(torch.float) + c = torch.isclose(x, y, atol=1e-4, rtol=1e-2, equal_nan=True) + if c.all(): + return True, 0, lambda t: t + else: + d = x[c == False] - y[c == False] + x_d = x[c == False] + y_d = y[c == False] + xy = torch.logical_or( + torch.logical_and(x_d > 1e30, y_d == float("inf")), + torch.logical_and(x_d == float("inf"), y_d > 1e30), + ) + if xy.all(): + return True, 0, lambda t: t + else: + ix = xy == False + return False, d[ix].abs().max().item(), lambda t: t[c == False][ix] + + class ArgumentTupleGenerator: def __init__(self, spec: Spec, config: Optional[TensorConfig] = None): self.spec = spec @@ -188,21 +209,23 @@ def gen_errors( if isinstance(ret, torch.Tensor) and isinstance( cpu_ret, torch.Tensor ): - if not torch.allclose( - cpu_ret, ret.to("cpu"), equal_nan=True - ): - cpu_ret_f = cpu_ret.to(torch.float) - ret_f = ret.to("cpu").to(torch.float) - - max_diff = (cpu_ret_f - ret_f).abs().max() + allclose, max_diff, ix_fun = compare_tensors( + cpu_ret, ret.to("cpu") + ) + if not allclose: if verbose: print(f"Output mismatch: {max_diff}") print( op.__name__, str([str(x) for x in meta_tuple]) ) - if ret.numel() < 10: - print(ret) - print(cpu_ret) + # for cpu_arg in cpu_posargs: + # if isinstance(cpu_arg, torch.Tensor) and cpu_arg.numel() < 10: + # print(cpu_arg) + # elif isinstance(cpu_arg, torch.Tensor) and cpu_arg.shape == cpu_ret.shape: + # print(ix_fun(cpu_arg)) + # if ret.numel() < 10: + # print(ret) + # print(cpu_ret) yield posargs, inkwargs, outargs else: # When valid=False, we expect failure, so success IS a bug