|
8 | 8 | from copy import deepcopy
|
9 | 9 | from typing import Any, Generator, List, Optional, Tuple
|
10 | 10 |
|
| 11 | +import torch |
11 | 12 | from facto.inputgen.argtuple.engine import MetaArgTupleEngine
|
12 | 13 | from facto.inputgen.argument.engine import MetaArg
|
13 | 14 | from facto.inputgen.argument.gen import ArgumentGenerator
|
@@ -99,7 +100,13 @@ def gen(
|
99 | 100 | yield self.gen_tuple(meta_tuple, out=out)
|
100 | 101 |
|
101 | 102 | def gen_errors(
|
102 |
| - self, op, *, valid: bool = True, out: bool = False, verbose: bool = False |
| 103 | + self, |
| 104 | + op, |
| 105 | + *, |
| 106 | + valid: bool = True, |
| 107 | + out: bool = False, |
| 108 | + verbose: bool = False, |
| 109 | + check_correctness: bool = False, |
103 | 110 | ) -> Generator[
|
104 | 111 | Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
|
105 | 112 | ]:
|
@@ -128,15 +135,64 @@ def gen_errors(
|
128 | 135 | # Try to execute the operation with the generated inputs
|
129 | 136 | if out:
|
130 | 137 | # If there are output arguments, include them in the call
|
131 |
| - op(*posargs, **inkwargs, **outargs) |
| 138 | + ret = op(*posargs, **inkwargs, **outargs) |
132 | 139 | else:
|
133 | 140 | # Otherwise, just call with positional and keyword arguments
|
134 |
| - op(*posargs, **inkwargs) |
| 141 | + ret = op(*posargs, **inkwargs) |
135 | 142 |
|
136 | 143 | # If execution succeeds:
|
137 | 144 | if valid:
|
138 | 145 | # When valid=True, we expect success, so this is NOT a bug
|
139 |
| - continue |
| 146 | + if ( |
| 147 | + check_correctness |
| 148 | + and self.config is not None |
| 149 | + and self.config.device != "cpu" |
| 150 | + ): |
| 151 | + # If correctness=True, and device != cpu we also check if the output is correct |
| 152 | + # by comparing it to the cpu output |
| 153 | + cpu_posargs = [] |
| 154 | + cpu_inkwargs = OrderedDict() |
| 155 | + cpu_outargs = OrderedDict() |
| 156 | + for arg in posargs: |
| 157 | + new = arg |
| 158 | + if isinstance(arg, torch.Tensor): |
| 159 | + new = arg.to("cpu") |
| 160 | + cpu_posargs.append(new) |
| 161 | + for k, v in inkwargs.items(): |
| 162 | + new = v |
| 163 | + if isinstance(v, torch.Tensor): |
| 164 | + new = v.to("cpu") |
| 165 | + cpu_inkwargs[k] = new |
| 166 | + for k, v in outargs.items(): |
| 167 | + new = v |
| 168 | + if isinstance(v, torch.Tensor): |
| 169 | + new = v.to("cpu") |
| 170 | + cpu_outargs[k] = new |
| 171 | + |
| 172 | + try: |
| 173 | + cpu_ret = op(*cpu_posargs, **cpu_inkwargs, **cpu_outargs) |
| 174 | + except Exception: |
| 175 | + continue |
| 176 | + |
| 177 | + if isinstance(ret, torch.Tensor) and isinstance( |
| 178 | + cpu_ret, torch.Tensor |
| 179 | + ): |
| 180 | + if not torch.allclose( |
| 181 | + cpu_ret, ret.to("cpu"), equal_nan=True |
| 182 | + ): |
| 183 | + cpu_ret_f = cpu_ret.to(torch.float) |
| 184 | + ret_f = ret.to("cpu").to(torch.float) |
| 185 | + |
| 186 | + max_diff = (cpu_ret_f - ret_f).abs().max() |
| 187 | + if verbose: |
| 188 | + print(f"Output mismatch: {max_diff}") |
| 189 | + print( |
| 190 | + op.__name__, str([str(x) for x in meta_tuple]) |
| 191 | + ) |
| 192 | + if ret.numel() < 10: |
| 193 | + print(ret) |
| 194 | + print(cpu_ret) |
| 195 | + yield posargs, inkwargs, outargs |
140 | 196 | else:
|
141 | 197 | # When valid=False, we expect failure, so success IS a bug
|
142 | 198 | if verbose:
|
|
0 commit comments