diff --git a/facto/inputgen/argtuple/gen.py b/facto/inputgen/argtuple/gen.py index b139753..e9ef87a 100644 --- a/facto/inputgen/argtuple/gen.py +++ b/facto/inputgen/argtuple/gen.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import Any, Generator, List, Optional, Tuple +import torch from facto.inputgen.argtuple.engine import MetaArgTupleEngine from facto.inputgen.argument.engine import MetaArg from facto.inputgen.argument.gen import ArgumentGenerator @@ -99,7 +100,13 @@ def gen( yield self.gen_tuple(meta_tuple, out=out) def gen_errors( - self, op, *, valid: bool = True, out: bool = False, verbose: bool = False + self, + op, + *, + valid: bool = True, + out: bool = False, + verbose: bool = False, + check_correctness: bool = False, ) -> Generator[ Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any ]: @@ -128,15 +135,64 @@ def gen_errors( # Try to execute the operation with the generated inputs if out: # If there are output arguments, include them in the call - op(*posargs, **inkwargs, **outargs) + ret = op(*posargs, **inkwargs, **outargs) else: # Otherwise, just call with positional and keyword arguments - op(*posargs, **inkwargs) + ret = op(*posargs, **inkwargs) # If execution succeeds: if valid: # When valid=True, we expect success, so this is NOT a bug - continue + if ( + check_correctness + and self.config is not None + and self.config.device != "cpu" + ): + # If correctness=True, and device != cpu we also check if the output is correct + # by comparing it to the cpu output + cpu_posargs = [] + cpu_inkwargs = OrderedDict() + cpu_outargs = OrderedDict() + for arg in posargs: + new = arg + if isinstance(arg, torch.Tensor): + new = arg.to("cpu") + cpu_posargs.append(new) + for k, v in inkwargs.items(): + new = v + if isinstance(v, torch.Tensor): + new = v.to("cpu") + cpu_inkwargs[k] = new + for k, v in outargs.items(): + new = v + if isinstance(v, torch.Tensor): + new = v.to("cpu") + cpu_outargs[k] = new + + try: + cpu_ret = op(*cpu_posargs, **cpu_inkwargs, **cpu_outargs) + except Exception: + continue + + if isinstance(ret, torch.Tensor) and isinstance( + cpu_ret, torch.Tensor + ): + if not torch.allclose( + cpu_ret, ret.to("cpu"), equal_nan=True + ): + cpu_ret_f = cpu_ret.to(torch.float) + ret_f = ret.to("cpu").to(torch.float) + + max_diff = (cpu_ret_f - ret_f).abs().max() + if verbose: + print(f"Output mismatch: {max_diff}") + print( + op.__name__, str([str(x) for x in meta_tuple]) + ) + if ret.numel() < 10: + print(ret) + print(cpu_ret) + yield posargs, inkwargs, outargs else: # When valid=False, we expect failure, so success IS a bug if verbose: diff --git a/test/specdb/base_test.py b/test/specdb/base_test.py index 55c8230..0195ff0 100644 --- a/test/specdb/base_test.py +++ b/test/specdb/base_test.py @@ -16,7 +16,13 @@ class BaseSpecDBTest(unittest.TestCase): """Base test class for validating all specs in SpecDB using gen_errors.""" - def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None): + def _run_op( + self, + op_name: str, + *, + config: Optional[TensorConfig] = None, + check_correctness: bool = False, + ): """ Run a single op in SpecDB with a given TensorConfig @@ -37,7 +43,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None): try: errors = list( - generator.gen_errors(op, valid=True, out=False, verbose=True) + generator.gen_errors( + op, + valid=True, + out=False, + verbose=True, + check_correctness=check_correctness, + ) ) except Exception as e: self.fail(f"Failed while testing operation {op_name}: {e}") @@ -47,7 +59,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None): f"Found {len(errors)} errors for {op_name} with valid=True, out=False" ) - def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]): + def _run_all_ops( + self, + *, + config: Optional[TensorConfig] = None, + skip_ops=[], + check_correctness: bool = False, + ): """ Run all ops in SpecDB with a given TensorConfig @@ -61,4 +79,4 @@ def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]): for op_name in op_names: if op_name in skip_ops: continue - self._run_op(op_name, config=config) + self._run_op(op_name, config=config, check_correctness=check_correctness) diff --git a/test/specdb/test_specdb_mps.py b/test/specdb/test_specdb_mps.py index fc8128c..ea59739 100644 --- a/test/specdb/test_specdb_mps.py +++ b/test/specdb/test_specdb_mps.py @@ -95,6 +95,12 @@ def test_all_ops_mps_strided(self): ) self._run_all_ops(config=config, skip_ops=skip_ops) + def test_correctness_all_ops_mps(self): + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], half_precision=False + ) + self._run_all_ops(config=config, skip_ops=self.SKIP_OPS, check_correctness=True) + if __name__ == "__main__": unittest.main()