diff --git a/facto/inputgen/argtuple/gen.py b/facto/inputgen/argtuple/gen.py index ac180cf..b139753 100644 --- a/facto/inputgen/argtuple/gen.py +++ b/facto/inputgen/argtuple/gen.py @@ -53,6 +53,20 @@ def _apply_constraints_to_arg(self, arg, config: TensorConfig): size_constraint = cp.Size.Ge(lambda deps, r, d: 1) modified_arg.constraints = modified_arg.constraints + [size_constraint] + # Add dtype constraints for tensor arguments when dtypes are disallowed + if config.disallow_dtypes: + if arg.type.is_tensor(): + dtype_constraint = cp.Dtype.NotIn(lambda deps: config.disallow_dtypes) + modified_arg.constraints = modified_arg.constraints + [dtype_constraint] + elif arg.type.is_tensor_list(): + dtype_constraint = cp.Dtype.NotIn( + lambda deps, length, ix: config.disallow_dtypes + ) + modified_arg.constraints = modified_arg.constraints + [dtype_constraint] + elif arg.type.is_scalar_type(): + dtype_constraint = cp.Value.NotIn(lambda deps: config.disallow_dtypes) + modified_arg.constraints = modified_arg.constraints + [dtype_constraint] + return modified_arg def gen_tuple( @@ -85,7 +99,7 @@ def gen( yield self.gen_tuple(meta_tuple, out=out) def gen_errors( - self, op, *, valid: bool = True, out: bool = False + self, op, *, valid: bool = True, out: bool = False, verbose: bool = False ) -> Generator[ Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any ]: @@ -105,7 +119,11 @@ def gen_errors( Yields: Tuples of (posargs, inkwargs, outargs) that don't behave as expected """ - for posargs, inkwargs, outargs in self.gen(valid=valid, out=out): + + engine = MetaArgTupleEngine(self._modified_spec, out=out) + for meta_tuple in engine.gen(valid=valid): + posargs, inkwargs, outargs = self.gen_tuple(meta_tuple, out=out) + try: # Try to execute the operation with the generated inputs if out: @@ -121,12 +139,18 @@ def gen_errors( continue else: # When valid=False, we expect failure, so success IS a bug + if verbose: + print(f"Unexpected success:") + print(op.__name__, str([str(x) for x in meta_tuple])) yield posargs, inkwargs, outargs - except Exception: + except Exception as e: # If execution fails: if valid: # When valid=True, we expect success, so failure IS a bug + if verbose: + print(op.__name__, str([str(x) for x in meta_tuple])) + print(f"Exception occurred: {e}") yield posargs, inkwargs, outargs else: # When valid=False, we expect failure, so this is NOT a bug diff --git a/facto/inputgen/utils/config.py b/facto/inputgen/utils/config.py index bb16f84..2e80cc6 100644 --- a/facto/inputgen/utils/config.py +++ b/facto/inputgen/utils/config.py @@ -12,11 +12,13 @@ class Condition(str, Enum): ALLOW_TRANSPOSED = "transposed" ALLOW_PERMUTED = "permuted" ALLOW_STRIDED = "strided" + DISALLOW_DTYPES = "disallow_dtypes" class TensorConfig: - def __init__(self, device="cpu", **conditions): + def __init__(self, device="cpu", disallow_dtypes=None, **conditions): self.device = device + self.disallow_dtypes = disallow_dtypes or [] self.conditions = {condition: False for condition in Condition} for condition, value in conditions.items(): if condition in self.conditions: @@ -26,6 +28,10 @@ def __init__(self, device="cpu", **conditions): def is_allowed(self, condition: Condition) -> bool: return self.conditions.get(condition, False) + def is_dtype_disallowed(self, dtype) -> bool: + """Check if a given dtype is in the disallow list.""" + return dtype in self.disallow_dtypes + def set_probability(self, probability: float) -> "TensorConfig": self.probability = probability return self diff --git a/facto/inputgen/variable/solve.py b/facto/inputgen/variable/solve.py index 6435017..9430b9d 100644 --- a/facto/inputgen/variable/solve.py +++ b/facto/inputgen/variable/solve.py @@ -35,7 +35,7 @@ def __init__(self, vtype: type): def Eq(self, v: Any) -> None: if invalid_vtype(self.vtype, v): - raise TypeError("Variable type mismatch") + raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}") if self.space.empty(): return if self.space.contains(v): @@ -45,7 +45,7 @@ def Eq(self, v: Any) -> None: def Ne(self, v: Any) -> None: if invalid_vtype(self.vtype, v): - raise TypeError("Variable type mismatch") + raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}") if self.space.empty(): return self.space.remove(v) @@ -53,7 +53,9 @@ def Ne(self, v: Any) -> None: def In(self, values: List[Any]) -> None: for v in values: if invalid_vtype(self.vtype, v): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {v} is not of type {self.vtype}" + ) if self.space.empty(): return self.space.discrete = Discrete( @@ -63,7 +65,9 @@ def In(self, values: List[Any]) -> None: def NotIn(self, values: List[Any]) -> None: for v in values: if invalid_vtype(self.vtype, v): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {v} is not of type {self.vtype}" + ) if self.space.empty(): return for v in values: @@ -73,7 +77,9 @@ def Le(self, upper: Union[bool, int, float]) -> None: if self.vtype not in [bool, int, float]: raise Exception(f"Le is not valid constraint on {self.vtype}") if invalid_vtype(self.vtype, upper): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {upper} is not of type {self.vtype}" + ) if self.space.empty(): return elif self.space.discrete.initialized: @@ -90,7 +96,9 @@ def Lt(self, upper: Union[bool, int, float]) -> None: if self.vtype not in [bool, int, float]: raise Exception(f"Lt is not valid constraint on {self.vtype}") if invalid_vtype(self.vtype, upper): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {upper} is not of type {self.vtype}" + ) if self.space.empty(): return elif self.space.discrete.initialized: @@ -109,7 +117,9 @@ def Ge(self, lower: Union[bool, int, float]) -> None: if self.vtype not in [bool, int, float]: raise Exception(f"Ge is not valid constraint on {self.vtype}") if invalid_vtype(self.vtype, lower): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {lower} is not of type {self.vtype}" + ) if self.space.empty(): return elif self.space.discrete.initialized: @@ -126,7 +136,9 @@ def Gt(self, lower: Union[bool, int, float]) -> None: if self.vtype not in [bool, int, float]: raise Exception(f"Gt is not valid constraint on {self.vtype}") if invalid_vtype(self.vtype, lower): - raise TypeError("Variable type mismatch") + raise TypeError( + f"Variable type mismatch: {lower} is not of type {self.vtype}" + ) if self.space.empty(): return elif self.space.discrete.initialized: diff --git a/test/specdb/__init__.py b/test/specdb/__init__.py new file mode 100644 index 0000000..45742ae --- /dev/null +++ b/test/specdb/__init__.py @@ -0,0 +1 @@ +# SpecDB test package diff --git a/test/specdb/base_test.py b/test/specdb/base_test.py new file mode 100644 index 0000000..55c8230 --- /dev/null +++ b/test/specdb/base_test.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Optional + +from facto.inputgen.argtuple.gen import ArgumentTupleGenerator +from facto.inputgen.utils.config import TensorConfig +from facto.specdb.db import SpecDictDB +from facto.utils.ops import get_op_overload + + +class BaseSpecDBTest(unittest.TestCase): + """Base test class for validating all specs in SpecDB using gen_errors.""" + + def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None): + """ + Run a single op in SpecDB with a given TensorConfig + + This test calls ArgumentTupleGenerator.gen_errors with valid=True, out=False + for a single operation. The operation is tested as a subtest. + """ + print("Testing op: ", op_name) + with self.subTest(op=op_name): + try: + # Get the spec and operation + spec = SpecDictDB[op_name] + op = get_op_overload(op_name) + generator = ArgumentTupleGenerator(spec, config) + except Exception as e: + # If we can't resolve the operation or there's another issue, + # fail this subtest with a descriptive message + self.fail(f"Failed to test operation {op_name}: {e}") + + try: + errors = list( + generator.gen_errors(op, valid=True, out=False, verbose=True) + ) + except Exception as e: + self.fail(f"Failed while testing operation {op_name}: {e}") + + if len(errors) > 0: + self.fail( + f"Found {len(errors)} errors for {op_name} with valid=True, out=False" + ) + + def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=[]): + """ + Run all ops in SpecDB with a given TensorConfig + + This test iterates through all operations in SpecDB and calls + ArgumentTupleGenerator.gen_errors with valid=True, out=False + for each operation. Each operation is tested as a subtest. + """ + # Get all operation names from SpecDB + op_names = list(SpecDictDB.keys()) + + for op_name in op_names: + if op_name in skip_ops: + continue + self._run_op(op_name, config=config) diff --git a/test/specdb/test_specdb.py b/test/specdb/test_specdb.py deleted file mode 100644 index 66d86a7..0000000 --- a/test/specdb/test_specdb.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from facto.inputgen.argtuple.gen import ArgumentTupleGenerator -from facto.specdb.db import SpecDictDB -from facto.utils.ops import get_op_overload - - -class TestSpecDBOperations(unittest.TestCase): - """Test class for validating all specs in SpecDB using gen_errors.""" - - def test_all_ops(self): - """ - Test all ops in SpecDB. - - This test iterates through all operations in SpecDB and calls - ArgumentTupleGenerator.gen_errors with valid=True, out=False - for each operation. Each operation is tested as a subtest. - """ - # Get all operation names from SpecDB - op_names = list(SpecDictDB.keys()) - - skip_ops = [ - "_native_batch_norm_legit_no_training.default", - "addmm.default", - "arange.default", - "arange.start_step", - "constant_pad_nd.default", - "split_with_sizes_copy.default", - ] - - for op_name in op_names: - if op_name in skip_ops: - continue - with self.subTest(op=op_name): - try: - # Get the spec and operation - spec = SpecDictDB[op_name] - op = get_op_overload(op_name) - generator = ArgumentTupleGenerator(spec) - except Exception as e: - # If we can't resolve the operation or there's another issue, - # fail this subtest with a descriptive message - self.fail(f"Failed to test operation {op_name}: {e}") - - try: - errors = list(generator.gen_errors(op, valid=True, out=False)) - except Exception as e: - self.fail(f"Failed while testing operation {op_name}: {e}") - - if len(errors) > 0: - self.fail( - f"Found {len(errors)} errors for {op_name} with valid=True, out=False" - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/specdb/test_specdb_cpu.py b/test/specdb/test_specdb_cpu.py new file mode 100644 index 0000000..1d93654 --- /dev/null +++ b/test/specdb/test_specdb_cpu.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from base_test import BaseSpecDBTest + + +class TestSpecDBOperationsCPU(BaseSpecDBTest): + """Test class for validating all specs in SpecDB using gen_errors on CPU.""" + + def test_all_ops_cpu(self): + skip_ops = [ + "_native_batch_norm_legit_no_training.default", + "addmm.default", + "arange.default", + "arange.start_step", + "constant_pad_nd.default", + "split_with_sizes_copy.default", + ] + + self._run_all_ops(skip_ops=skip_ops) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/specdb/test_specdb_mps.py b/test/specdb/test_specdb_mps.py new file mode 100644 index 0000000..60de098 --- /dev/null +++ b/test/specdb/test_specdb_mps.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from base_test import BaseSpecDBTest +from facto.inputgen.utils.config import TensorConfig + + +class TestSpecDBOperationsMPS(BaseSpecDBTest): + """Test class for validating all specs in SpecDB using gen_errors on MPS.""" + + def test_all_ops_mps(self): + skip_ops = [ + # Calibrate specs (cpu not passing either): + "addmm.default", + "arange.default", + "arange.start_step", + "constant_pad_nd.default", + "split_with_sizes_copy.default", + # https://github.com/pytorch/pytorch/issues/160208 + "add.Tensor", + "add.Scalar", + "rsub.Scalar", + "sub.Tensor", + "sub.Scalar", + # crash: https://github.com/pytorch/pytorch/issues/154887 + "_native_batch_norm_legit_no_training.default", + # not implemented + "_pdist_forward.default", + # impl: clamp tensor number of dims must not be greater than that of input tensor + "clamp.Tensor", + # crash: https://github.com/pytorch/pytorch/issues/154881 + "cumsum.default", + # sparse_grad not supported in MPS yet + "gather.default", + # Dimension specified as -1 but tensor has no dimensions + "index_select.default", + # crash: https://github.com/pytorch/pytorch/issues/154882 + "max_pool2d_with_indices.default", + # On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890 + # https://github.com/pytorch/pytorch/issues/154890 + "topk.default", + # var_mps: reduction dim must be in the range of input shape + "var.correction", + "var.dim", + ] + + config = TensorConfig(device="mps", disallow_dtypes=[torch.float64]) + self._run_all_ops(config=config, skip_ops=skip_ops) + + +if __name__ == "__main__": + unittest.main()