|
16 | 16 | from facto.inputgen.utils.config import Condition, TensorConfig
|
17 | 17 |
|
18 | 18 |
|
def compare_tensors(x, y):
    """Compare two tensors element-wise, tolerating overflow saturation.

    Both inputs are converted to ``torch.float`` and compared with
    ``torch.isclose`` (``atol=1e-4``, ``rtol=1e-2``, NaNs compare equal).
    An element that fails that check is still accepted when one side is an
    infinity and the other side has the same sign with magnitude above
    ``1e30``; i.e. one computation saturated to +/-inf while the other
    produced a very large finite value.

    Args:
        x: First tensor.
        y: Second tensor. It is indexed with the mask derived from
            ``torch.isclose(x, y)``, so it must be compatible with that mask.

    Returns:
        A tuple ``(ok, max_diff, ix_fun)``:
            * ``ok`` -- ``True`` when every element is close or is an
              accepted overflow pair, ``False`` otherwise.
            * ``max_diff`` -- ``0`` when ``ok`` is ``True``; otherwise the
              largest absolute difference among the genuinely mismatching
              elements, as a Python number.
            * ``ix_fun`` -- a callable that, given a tensor indexable by the
              same mask, returns its entries at the genuinely mismatching
              positions. It is the identity when ``ok`` is ``True``.
    """
    x = x.to(torch.float)
    y = y.to(torch.float)
    close = torch.isclose(x, y, atol=1e-4, rtol=1e-2, equal_nan=True)
    if close.all():
        return True, 0, lambda t: t

    # Compute the mismatch mask once and reuse it for every selection.
    mismatch = ~close
    x_bad = x[mismatch]
    y_bad = y[mismatch]

    inf = float("inf")
    # Accept saturation in both directions: +inf vs. huge positive and
    # -inf vs. huge negative. Opposite-sign pairs never match any clause.
    saturated = (
        ((x_bad > 1e30) & (y_bad == inf))
        | ((x_bad == inf) & (y_bad > 1e30))
        | ((x_bad < -1e30) & (y_bad == -inf))
        | ((x_bad == -inf) & (y_bad < -1e30))
    )
    if saturated.all():
        return True, 0, lambda t: t

    # At least one real mismatch remains, so the selection below is non-empty.
    real = ~saturated
    max_diff = (x_bad - y_bad)[real].abs().max().item()
    return False, max_diff, lambda t: t[mismatch][real]
| 38 | + |
| 39 | + |
19 | 40 | class ArgumentTupleGenerator:
|
20 | 41 | def __init__(self, spec: Spec, config: Optional[TensorConfig] = None):
|
21 | 42 | self.spec = spec
|
@@ -188,21 +209,23 @@ def gen_errors(
|
188 | 209 | if isinstance(ret, torch.Tensor) and isinstance(
|
189 | 210 | cpu_ret, torch.Tensor
|
190 | 211 | ):
|
191 |
| - if not torch.allclose( |
192 |
| - cpu_ret, ret.to("cpu"), equal_nan=True |
193 |
| - ): |
194 |
| - cpu_ret_f = cpu_ret.to(torch.float) |
195 |
| - ret_f = ret.to("cpu").to(torch.float) |
196 |
| - |
197 |
| - max_diff = (cpu_ret_f - ret_f).abs().max() |
| 212 | + allclose, max_diff, ix_fun = compare_tensors( |
| 213 | + cpu_ret, ret.to("cpu") |
| 214 | + ) |
| 215 | + if not allclose: |
198 | 216 | if verbose:
|
199 | 217 | print(f"Output mismatch: {max_diff}")
|
200 | 218 | print(
|
201 | 219 | op.__name__, str([str(x) for x in meta_tuple])
|
202 | 220 | )
|
203 |
| - if ret.numel() < 10: |
204 |
| - print(ret) |
205 |
| - print(cpu_ret) |
| 221 | + # for cpu_arg in cpu_posargs: |
| 222 | + # if isinstance(cpu_arg, torch.Tensor) and cpu_arg.numel() < 10: |
| 223 | + # print(cpu_arg) |
| 224 | + # elif isinstance(cpu_arg, torch.Tensor) and cpu_arg.shape == cpu_ret.shape: |
| 225 | + # print(ix_fun(cpu_arg)) |
| 226 | + # if ret.numel() < 10: |
| 227 | + # print(ret) |
| 228 | + # print(cpu_ret) |
206 | 229 | yield posargs, inkwargs, outargs
|
207 | 230 | else:
|
208 | 231 | # When valid=False, we expect failure, so success IS a bug
|
|
0 commit comments