Skip to content

Commit da927e1

Browse files
manuelcandales and facebook-github-bot
authored and committed
Test correctness mps ops (#41)
Summary:
Pull Request resolved: #41

imported-using-ghimport

Test Plan: Imported from OSS

Rollback Plan:

Reviewed By: digantdesai

Differential Revision: D80468306

Pulled By: manuelcandales

fbshipit-source-id: e26d7be2a97d5906d2530c0f5012efd44f6781c7
1 parent 65cacec commit da927e1

File tree

3 files changed

+88
-8
lines changed

3 files changed

+88
-8
lines changed

facto/inputgen/argtuple/gen.py

Lines changed: 60 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from copy import deepcopy
99
from typing import Any, Generator, List, Optional, Tuple
1010

11+
import torch
1112
from facto.inputgen.argtuple.engine import MetaArgTupleEngine
1213
from facto.inputgen.argument.engine import MetaArg
1314
from facto.inputgen.argument.gen import ArgumentGenerator
@@ -99,7 +100,13 @@ def gen(
99100
yield self.gen_tuple(meta_tuple, out=out)
100101

101102
def gen_errors(
102-
self, op, *, valid: bool = True, out: bool = False, verbose: bool = False
103+
self,
104+
op,
105+
*,
106+
valid: bool = True,
107+
out: bool = False,
108+
verbose: bool = False,
109+
check_correctness: bool = False,
103110
) -> Generator[
104111
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
105112
]:
@@ -128,15 +135,64 @@ def gen_errors(
128135
# Try to execute the operation with the generated inputs
129136
if out:
130137
# If there are output arguments, include them in the call
131-
op(*posargs, **inkwargs, **outargs)
138+
ret = op(*posargs, **inkwargs, **outargs)
132139
else:
133140
# Otherwise, just call with positional and keyword arguments
134-
op(*posargs, **inkwargs)
141+
ret = op(*posargs, **inkwargs)
135142

136143
# If execution succeeds:
137144
if valid:
138145
# When valid=True, we expect success, so this is NOT a bug
139-
continue
146+
if (
147+
check_correctness
148+
and self.config is not None
149+
and self.config.device != "cpu"
150+
):
151+
# If correctness=True, and device != cpu we also check if the output is correct
152+
# by comparing it to the cpu output
153+
cpu_posargs = []
154+
cpu_inkwargs = OrderedDict()
155+
cpu_outargs = OrderedDict()
156+
for arg in posargs:
157+
new = arg
158+
if isinstance(arg, torch.Tensor):
159+
new = arg.to("cpu")
160+
cpu_posargs.append(new)
161+
for k, v in inkwargs.items():
162+
new = v
163+
if isinstance(v, torch.Tensor):
164+
new = v.to("cpu")
165+
cpu_inkwargs[k] = new
166+
for k, v in outargs.items():
167+
new = v
168+
if isinstance(v, torch.Tensor):
169+
new = v.to("cpu")
170+
cpu_outargs[k] = new
171+
172+
try:
173+
cpu_ret = op(*cpu_posargs, **cpu_inkwargs, **cpu_outargs)
174+
except Exception:
175+
continue
176+
177+
if isinstance(ret, torch.Tensor) and isinstance(
178+
cpu_ret, torch.Tensor
179+
):
180+
if not torch.allclose(
181+
cpu_ret, ret.to("cpu"), equal_nan=True
182+
):
183+
cpu_ret_f = cpu_ret.to(torch.float)
184+
ret_f = ret.to("cpu").to(torch.float)
185+
186+
max_diff = (cpu_ret_f - ret_f).abs().max()
187+
if verbose:
188+
print(f"Output mismatch: {max_diff}")
189+
print(
190+
op.__name__, str([str(x) for x in meta_tuple])
191+
)
192+
if ret.numel() < 10:
193+
print(ret)
194+
print(cpu_ret)
195+
yield posargs, inkwargs, outargs
140196
else:
141197
# When valid=False, we expect failure, so success IS a bug
142198
if verbose:

test/specdb/base_test.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,13 @@
1616
class BaseSpecDBTest(unittest.TestCase):
1717
"""Base test class for validating all specs in SpecDB using gen_errors."""
1818

19-
def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
19+
def _run_op(
20+
self,
21+
op_name: str,
22+
*,
23+
config: Optional[TensorConfig] = None,
24+
check_correctness: bool = False,
25+
):
2026
"""
2127
Run a single op in SpecDB with a given TensorConfig
2228
@@ -37,7 +43,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
3743

3844
try:
3945
errors = list(
40-
generator.gen_errors(op, valid=True, out=False, verbose=True)
46+
generator.gen_errors(
47+
op,
48+
valid=True,
49+
out=False,
50+
verbose=True,
51+
check_correctness=check_correctness,
52+
)
4153
)
4254
except Exception as e:
4355
self.fail(f"Failed while testing operation {op_name}: {e}")
@@ -47,7 +59,13 @@ def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
4759
f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
4860
)
4961

50-
def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]):
62+
def _run_all_ops(
63+
self,
64+
*,
65+
config: Optional[TensorConfig] = None,
66+
skip_ops=[],
67+
check_correctness: bool = False,
68+
):
5169
"""
5270
Run all ops in SpecDB with a given TensorConfig
5371
@@ -61,4 +79,4 @@ def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]):
6179
for op_name in op_names:
6280
if op_name in skip_ops:
6381
continue
64-
self._run_op(op_name, config=config)
82+
self._run_op(op_name, config=config, check_correctness=check_correctness)

test/specdb/test_specdb_mps.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,12 @@ def test_all_ops_mps_strided(self):
9595
)
9696
self._run_all_ops(config=config, skip_ops=skip_ops)
9797

98+
def test_correctness_all_ops_mps(self):
99+
config = TensorConfig(
100+
device="mps", disallow_dtypes=[torch.float64], half_precision=False
101+
)
102+
self._run_all_ops(config=config, skip_ops=self.SKIP_OPS, check_correctness=True)
103+
98104

99105
if __name__ == "__main__":
100106
unittest.main()

0 commit comments

Comments
 (0)