Skip to content

Test correctness mps ops #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/manuelcandales/23/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions facto/inputgen/argtuple/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
from typing import Any, Generator, List, Optional, Tuple

import torch
from facto.inputgen.argtuple.engine import MetaArgTupleEngine
from facto.inputgen.argument.engine import MetaArg
from facto.inputgen.argument.gen import ArgumentGenerator
Expand Down Expand Up @@ -99,7 +100,13 @@ def gen(
yield self.gen_tuple(meta_tuple, out=out)

def gen_errors(
self, op, *, valid: bool = True, out: bool = False, verbose: bool = False
self,
op,
*,
valid: bool = True,
out: bool = False,
verbose: bool = False,
check_correctness: bool = False,
) -> Generator[
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
]:
Expand Down Expand Up @@ -128,15 +135,64 @@ def gen_errors(
# Try to execute the operation with the generated inputs
if out:
# If there are output arguments, include them in the call
op(*posargs, **inkwargs, **outargs)
ret = op(*posargs, **inkwargs, **outargs)
else:
# Otherwise, just call with positional and keyword arguments
op(*posargs, **inkwargs)
ret = op(*posargs, **inkwargs)

# If execution succeeds:
if valid:
# When valid=True, we expect success, so this is NOT a bug
continue
if (
check_correctness
and self.config is not None
and self.config.device != "cpu"
):
# If correctness=True, and device != cpu we also check if the output is correct
# by comparing it to the cpu output
cpu_posargs = []
cpu_inkwargs = OrderedDict()
cpu_outargs = OrderedDict()
for arg in posargs:
new = arg
if isinstance(arg, torch.Tensor):
new = arg.to("cpu")
cpu_posargs.append(new)
for k, v in inkwargs.items():
new = v
if isinstance(v, torch.Tensor):
new = v.to("cpu")
cpu_inkwargs[k] = new
for k, v in outargs.items():
new = v
if isinstance(v, torch.Tensor):
new = v.to("cpu")
cpu_outargs[k] = new

try:
cpu_ret = op(*cpu_posargs, **cpu_inkwargs, **cpu_outargs)
except Exception:
continue

if isinstance(ret, torch.Tensor) and isinstance(
cpu_ret, torch.Tensor
):
if not torch.allclose(
cpu_ret, ret.to("cpu"), equal_nan=True
):
cpu_ret_f = cpu_ret.to(torch.float)
ret_f = ret.to("cpu").to(torch.float)

max_diff = (cpu_ret_f - ret_f).abs().max()
if verbose:
print(f"Output mismatch: {max_diff}")
print(
op.__name__, str([str(x) for x in meta_tuple])
)
if ret.numel() < 10:
print(ret)
print(cpu_ret)
yield posargs, inkwargs, outargs
else:
# When valid=False, we expect failure, so success IS a bug
if verbose:
Expand Down
26 changes: 22 additions & 4 deletions test/specdb/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
class BaseSpecDBTest(unittest.TestCase):
"""Base test class for validating all specs in SpecDB using gen_errors."""

def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
def _run_op(
self,
op_name: str,
*,
config: Optional[TensorConfig] = None,
check_correctness: bool = False,
):
"""
Run a single op in SpecDB with a given TensorConfig

Expand All @@ -37,7 +43,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):

try:
errors = list(
generator.gen_errors(op, valid=True, out=False, verbose=True)
generator.gen_errors(
op,
valid=True,
out=False,
verbose=True,
check_correctness=check_correctness,
)
)
except Exception as e:
self.fail(f"Failed while testing operation {op_name}: {e}")
Expand All @@ -47,7 +59,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
)

def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]):
def _run_all_ops(
self,
*,
config: Optional[TensorConfig] = None,
skip_ops=[],
check_correctness: bool = False,
):
"""
Run all ops in SpecDB with a given TensorConfig

Expand All @@ -61,4 +79,4 @@ def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]):
for op_name in op_names:
if op_name in skip_ops:
continue
self._run_op(op_name, config=config)
self._run_op(op_name, config=config, check_correctness=check_correctness)
6 changes: 6 additions & 0 deletions test/specdb/test_specdb_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def test_all_ops_mps_strided(self):
)
self._run_all_ops(config=config, skip_ops=skip_ops)

def test_correctness_all_ops_mps(self):
config = TensorConfig(
device="mps", disallow_dtypes=[torch.float64], half_precision=False
)
self._run_all_ops(config=config, skip_ops=self.SKIP_OPS, check_correctness=True)


if __name__ == "__main__":
unittest.main()