Skip to content

Commit 57dc468

Browse files
Test ops on MPS backend
ghstack-source-id: 0ddf828 Pull-Request: #36
1 parent f753ab5 commit 57dc468

File tree

8 files changed

+207
-76
lines changed

8 files changed

+207
-76
lines changed

facto/inputgen/argtuple/gen.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@ def _apply_constraints_to_arg(self, arg, config: TensorConfig):
5353
size_constraint = cp.Size.Ge(lambda deps, r, d: 1)
5454
modified_arg.constraints = modified_arg.constraints + [size_constraint]
5555

56+
# Add dtype constraints for tensor arguments when dtypes are disallowed
57+
if config.disallow_dtypes:
58+
if arg.type.is_tensor():
59+
dtype_constraint = cp.Dtype.NotIn(lambda deps: config.disallow_dtypes)
60+
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]
61+
elif arg.type.is_tensor_list():
62+
dtype_constraint = cp.Dtype.NotIn(
63+
lambda deps, length, ix: config.disallow_dtypes
64+
)
65+
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]
66+
elif arg.type.is_scalar_type():
67+
dtype_constraint = cp.Value.NotIn(lambda deps: config.disallow_dtypes)
68+
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]
69+
5670
return modified_arg
5771

5872
def gen_tuple(
@@ -85,7 +99,7 @@ def gen(
8599
yield self.gen_tuple(meta_tuple, out=out)
86100

87101
def gen_errors(
88-
self, op, *, valid: bool = True, out: bool = False
102+
self, op, *, valid: bool = True, out: bool = False, verbose: bool = False
89103
) -> Generator[
90104
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
91105
]:
@@ -105,7 +119,11 @@ def gen_errors(
105119
Yields:
106120
Tuples of (posargs, inkwargs, outargs) that don't behave as expected
107121
"""
108-
for posargs, inkwargs, outargs in self.gen(valid=valid, out=out):
122+
123+
engine = MetaArgTupleEngine(self._modified_spec, out=out)
124+
for meta_tuple in engine.gen(valid=valid):
125+
posargs, inkwargs, outargs = self.gen_tuple(meta_tuple, out=out)
126+
109127
try:
110128
# Try to execute the operation with the generated inputs
111129
if out:
@@ -121,12 +139,18 @@ def gen_errors(
121139
continue
122140
else:
123141
# When valid=False, we expect failure, so success IS a bug
142+
if verbose:
143+
print(f"Unexpected success:")
144+
print(op.__name__, str([str(x) for x in meta_tuple]))
124145
yield posargs, inkwargs, outargs
125146

126-
except Exception:
147+
except Exception as e:
127148
# If execution fails:
128149
if valid:
129150
# When valid=True, we expect success, so failure IS a bug
151+
if verbose:
152+
print(op.__name__, str([str(x) for x in meta_tuple]))
153+
print(f"Exception occurred: {e}")
130154
yield posargs, inkwargs, outargs
131155
else:
132156
# When valid=False, we expect failure, so this is NOT a bug

facto/inputgen/utils/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ class Condition(str, Enum):
1212
ALLOW_TRANSPOSED = "transposed"
1313
ALLOW_PERMUTED = "permuted"
1414
ALLOW_STRIDED = "strided"
15+
DISALLOW_DTYPES = "disallow_dtypes"
1516

1617

1718
class TensorConfig:
18-
def __init__(self, device="cpu", **conditions):
19+
def __init__(self, device="cpu", disallow_dtypes=None, **conditions):
1920
self.device = device
21+
self.disallow_dtypes = disallow_dtypes or []
2022
self.conditions = {condition: False for condition in Condition}
2123
for condition, value in conditions.items():
2224
if condition in self.conditions:
@@ -26,6 +28,10 @@ def __init__(self, device="cpu", **conditions):
2628
def is_allowed(self, condition: Condition) -> bool:
2729
return self.conditions.get(condition, False)
2830

31+
def is_dtype_disallowed(self, dtype) -> bool:
32+
"""Check if a given dtype is in the disallow list."""
33+
return dtype in self.disallow_dtypes
34+
2935
def set_probability(self, probability: float) -> "TensorConfig":
3036
self.probability = probability
3137
return self

facto/inputgen/variable/solve.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, vtype: type):
3535

3636
def Eq(self, v: Any) -> None:
3737
if invalid_vtype(self.vtype, v):
38-
raise TypeError("Variable type mismatch")
38+
raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}")
3939
if self.space.empty():
4040
return
4141
if self.space.contains(v):
@@ -45,15 +45,17 @@ def Eq(self, v: Any) -> None:
4545

4646
def Ne(self, v: Any) -> None:
4747
if invalid_vtype(self.vtype, v):
48-
raise TypeError("Variable type mismatch")
48+
raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}")
4949
if self.space.empty():
5050
return
5151
self.space.remove(v)
5252

5353
def In(self, values: List[Any]) -> None:
5454
for v in values:
5555
if invalid_vtype(self.vtype, v):
56-
raise TypeError("Variable type mismatch")
56+
raise TypeError(
57+
f"Variable type mismatch: {v} is not of type {self.vtype}"
58+
)
5759
if self.space.empty():
5860
return
5961
self.space.discrete = Discrete(
@@ -63,7 +65,9 @@ def In(self, values: List[Any]) -> None:
6365
def NotIn(self, values: List[Any]) -> None:
6466
for v in values:
6567
if invalid_vtype(self.vtype, v):
66-
raise TypeError("Variable type mismatch")
68+
raise TypeError(
69+
f"Variable type mismatch: {v} is not of type {self.vtype}"
70+
)
6771
if self.space.empty():
6872
return
6973
for v in values:
@@ -73,7 +77,9 @@ def Le(self, upper: Union[bool, int, float]) -> None:
7377
if self.vtype not in [bool, int, float]:
7478
raise Exception(f"Le is not valid constraint on {self.vtype}")
7579
if invalid_vtype(self.vtype, upper):
76-
raise TypeError("Variable type mismatch")
80+
raise TypeError(
81+
f"Variable type mismatch: {upper} is not of type {self.vtype}"
82+
)
7783
if self.space.empty():
7884
return
7985
elif self.space.discrete.initialized:
@@ -90,7 +96,9 @@ def Lt(self, upper: Union[bool, int, float]) -> None:
9096
if self.vtype not in [bool, int, float]:
9197
raise Exception(f"Lt is not valid constraint on {self.vtype}")
9298
if invalid_vtype(self.vtype, upper):
93-
raise TypeError("Variable type mismatch")
99+
raise TypeError(
100+
f"Variable type mismatch: {upper} is not of type {self.vtype}"
101+
)
94102
if self.space.empty():
95103
return
96104
elif self.space.discrete.initialized:
@@ -109,7 +117,9 @@ def Ge(self, lower: Union[bool, int, float]) -> None:
109117
if self.vtype not in [bool, int, float]:
110118
raise Exception(f"Ge is not valid constraint on {self.vtype}")
111119
if invalid_vtype(self.vtype, lower):
112-
raise TypeError("Variable type mismatch")
120+
raise TypeError(
121+
f"Variable type mismatch: {lower} is not of type {self.vtype}"
122+
)
113123
if self.space.empty():
114124
return
115125
elif self.space.discrete.initialized:
@@ -126,7 +136,9 @@ def Gt(self, lower: Union[bool, int, float]) -> None:
126136
if self.vtype not in [bool, int, float]:
127137
raise Exception(f"Gt is not valid constraint on {self.vtype}")
128138
if invalid_vtype(self.vtype, lower):
129-
raise TypeError("Variable type mismatch")
139+
raise TypeError(
140+
f"Variable type mismatch: {lower} is not of type {self.vtype}"
141+
)
130142
if self.space.empty():
131143
return
132144
elif self.space.discrete.initialized:

test/specdb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SpecDB test package

test/specdb/base_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from typing import Optional
9+
10+
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
11+
from facto.inputgen.utils.config import TensorConfig
12+
from facto.specdb.db import SpecDictDB
13+
from facto.utils.ops import get_op_overload
14+
15+
16+
class BaseSpecDBTest(unittest.TestCase):
17+
"""Base test class for validating all specs in SpecDB using gen_errors."""
18+
19+
def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
20+
"""
21+
Run a single op in SpecDB with a given TensorConfig
22+
23+
This test calls ArgumentTupleGenerator.gen_errors with valid=True, out=False
24+
for a single operation. The operation is tested as a subtest.
25+
"""
26+
print("Testing op: ", op_name)
27+
with self.subTest(op=op_name):
28+
try:
29+
# Get the spec and operation
30+
spec = SpecDictDB[op_name]
31+
op = get_op_overload(op_name)
32+
generator = ArgumentTupleGenerator(spec, config)
33+
except Exception as e:
34+
# If we can't resolve the operation or there's another issue,
35+
# fail this subtest with a descriptive message
36+
self.fail(f"Failed to test operation {op_name}: {e}")
37+
38+
try:
39+
errors = list(
40+
generator.gen_errors(op, valid=True, out=False, verbose=True)
41+
)
42+
except Exception as e:
43+
self.fail(f"Failed while testing operation {op_name}: {e}")
44+
45+
if len(errors) > 0:
46+
self.fail(
47+
f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
48+
)
49+
50+
def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]):
51+
"""
52+
Run all ops in SpecDB with a given TensorConfig
53+
54+
This test iterates through all operations in SpecDB and calls
55+
ArgumentTupleGenerator.gen_errors with valid=True, out=False
56+
for each operation. Each operation is tested as a subtest.
57+
"""
58+
# Get all operation names from SpecDB
59+
op_names = list(SpecDictDB.keys())
60+
61+
for op_name in op_names:
62+
if op_name in skip_ops:
63+
continue
64+
self._run_op(op_name, config=config)

test/specdb/test_specdb.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

test/specdb/test_specdb_cpu.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from base_test import BaseSpecDBTest
10+
11+
12+
class TestSpecDBOperationsCPU(BaseSpecDBTest):
13+
"""Test class for validating all specs in SpecDB using gen_errors on CPU."""
14+
15+
def test_all_ops_cpu(self):
16+
skip_ops = [
17+
"_native_batch_norm_legit_no_training.default",
18+
"addmm.default",
19+
"arange.default",
20+
"arange.start_step",
21+
"constant_pad_nd.default",
22+
"split_with_sizes_copy.default",
23+
]
24+
25+
self._run_all_ops(skip_ops=skip_ops)
26+
27+
28+
if __name__ == "__main__":
29+
unittest.main()

test/specdb/test_specdb_mps.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from base_test import BaseSpecDBTest
12+
from facto.inputgen.utils.config import TensorConfig
13+
14+
15+
class TestSpecDBOperationsMPS(BaseSpecDBTest):
16+
"""Test class for validating all specs in SpecDB using gen_errors on MPS."""
17+
18+
def test_all_ops_mps(self):
19+
skip_ops = [
20+
# Calibrate specs (cpu not passing either):
21+
"addmm.default",
22+
"arange.default",
23+
"arange.start_step",
24+
"constant_pad_nd.default",
25+
"split_with_sizes_copy.default",
26+
# https://github.com/pytorch/pytorch/issues/160208
27+
"add.Tensor",
28+
"add.Scalar",
29+
"rsub.Scalar",
30+
"sub.Tensor",
31+
"sub.Scalar",
32+
# crash: https://github.com/pytorch/pytorch/issues/154887
33+
"_native_batch_norm_legit_no_training.default",
34+
# not implemented
35+
"_pdist_forward.default",
36+
# impl: clamp tensor number of dims must not be greater than that of input tensor
37+
"clamp.Tensor",
38+
# crash: https://github.com/pytorch/pytorch/issues/154881
39+
"cumsum.default",
40+
# sparse_grad not supported in MPS yet
41+
"gather.default",
42+
# Dimension specified as -1 but tensor has no dimensions
43+
"index_select.default",
44+
# crash: https://github.com/pytorch/pytorch/issues/154882
45+
"max_pool2d_with_indices.default",
46+
# On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890
47+
# https://github.com/pytorch/pytorch/issues/154890
48+
"topk.default",
49+
# var_mps: reduction dim must be in the range of input shape
50+
"var.correction",
51+
"var.dim",
52+
]
53+
54+
config = TensorConfig(device="mps", disallow_dtypes=[torch.float64])
55+
self._run_all_ops(config=config, skip_ops=skip_ops)
56+
57+
58+
if __name__ == "__main__":
59+
unittest.main()

0 commit comments

Comments
 (0)