Skip to content

Test ops on MPS backend #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/manuelcandales/18/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions facto/inputgen/argtuple/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def _apply_constraints_to_arg(self, arg, config: TensorConfig):
size_constraint = cp.Size.Ge(lambda deps, r, d: 1)
modified_arg.constraints = modified_arg.constraints + [size_constraint]

# Add dtype constraints for tensor arguments when dtypes are disallowed
if config.disallow_dtypes:
if arg.type.is_tensor():
dtype_constraint = cp.Dtype.NotIn(lambda deps: config.disallow_dtypes)
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]
elif arg.type.is_tensor_list():
dtype_constraint = cp.Dtype.NotIn(
lambda deps, length, ix: config.disallow_dtypes
)
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]
elif arg.type.is_scalar_type():
dtype_constraint = cp.Value.NotIn(lambda deps: config.disallow_dtypes)
modified_arg.constraints = modified_arg.constraints + [dtype_constraint]

return modified_arg

def gen_tuple(
Expand Down Expand Up @@ -85,7 +99,7 @@ def gen(
yield self.gen_tuple(meta_tuple, out=out)

def gen_errors(
self, op, *, valid: bool = True, out: bool = False
self, op, *, valid: bool = True, out: bool = False, verbose: bool = False
) -> Generator[
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
]:
Expand All @@ -105,7 +119,11 @@ def gen_errors(
Yields:
Tuples of (posargs, inkwargs, outargs) that don't behave as expected
"""
for posargs, inkwargs, outargs in self.gen(valid=valid, out=out):

engine = MetaArgTupleEngine(self._modified_spec, out=out)
for meta_tuple in engine.gen(valid=valid):
posargs, inkwargs, outargs = self.gen_tuple(meta_tuple, out=out)

try:
# Try to execute the operation with the generated inputs
if out:
Expand All @@ -121,12 +139,18 @@ def gen_errors(
continue
else:
# When valid=False, we expect failure, so success IS a bug
if verbose:
print(f"Unexpected success:")
print(op.__name__, str([str(x) for x in meta_tuple]))
yield posargs, inkwargs, outargs

except Exception:
except Exception as e:
# If execution fails:
if valid:
# When valid=True, we expect success, so failure IS a bug
if verbose:
print(op.__name__, str([str(x) for x in meta_tuple]))
print(f"Exception occurred: {e}")
yield posargs, inkwargs, outargs
else:
# When valid=False, we expect failure, so this is NOT a bug
Expand Down
8 changes: 7 additions & 1 deletion facto/inputgen/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ class Condition(str, Enum):
ALLOW_TRANSPOSED = "transposed"
ALLOW_PERMUTED = "permuted"
ALLOW_STRIDED = "strided"
DISALLOW_DTYPES = "disallow_dtypes"


class TensorConfig:
def __init__(self, device="cpu", **conditions):
def __init__(self, device="cpu", disallow_dtypes=None, **conditions):
self.device = device
self.disallow_dtypes = disallow_dtypes or []
self.conditions = {condition: False for condition in Condition}
for condition, value in conditions.items():
if condition in self.conditions:
Expand All @@ -26,6 +28,10 @@ def __init__(self, device="cpu", **conditions):
def is_allowed(self, condition: Condition) -> bool:
return self.conditions.get(condition, False)

def is_dtype_disallowed(self, dtype) -> bool:
    """Return True if *dtype* appears in this config's disallow list."""
    # Linear scan over the (small) disallow list; equivalent to `in`.
    return any(dtype == blocked for blocked in self.disallow_dtypes)

def set_probability(self, probability: float) -> "TensorConfig":
    """Set the sampling probability on this config; returns self for chaining."""
    self.probability = probability
    return self
28 changes: 20 additions & 8 deletions facto/inputgen/variable/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, vtype: type):

def Eq(self, v: Any) -> None:
if invalid_vtype(self.vtype, v):
raise TypeError("Variable type mismatch")
raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}")
if self.space.empty():
return
if self.space.contains(v):
Expand All @@ -45,15 +45,17 @@ def Eq(self, v: Any) -> None:

def Ne(self, v: Any) -> None:
if invalid_vtype(self.vtype, v):
raise TypeError("Variable type mismatch")
raise TypeError(f"Variable type mismatch: {v} is not of type {self.vtype}")
if self.space.empty():
return
self.space.remove(v)

def In(self, values: List[Any]) -> None:
for v in values:
if invalid_vtype(self.vtype, v):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {v} is not of type {self.vtype}"
)
if self.space.empty():
return
self.space.discrete = Discrete(
Expand All @@ -63,7 +65,9 @@ def In(self, values: List[Any]) -> None:
def NotIn(self, values: List[Any]) -> None:
for v in values:
if invalid_vtype(self.vtype, v):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {v} is not of type {self.vtype}"
)
if self.space.empty():
return
for v in values:
Expand All @@ -73,7 +77,9 @@ def Le(self, upper: Union[bool, int, float]) -> None:
if self.vtype not in [bool, int, float]:
raise Exception(f"Le is not valid constraint on {self.vtype}")
if invalid_vtype(self.vtype, upper):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {upper} is not of type {self.vtype}"
)
if self.space.empty():
return
elif self.space.discrete.initialized:
Expand All @@ -90,7 +96,9 @@ def Lt(self, upper: Union[bool, int, float]) -> None:
if self.vtype not in [bool, int, float]:
raise Exception(f"Lt is not valid constraint on {self.vtype}")
if invalid_vtype(self.vtype, upper):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {upper} is not of type {self.vtype}"
)
if self.space.empty():
return
elif self.space.discrete.initialized:
Expand All @@ -109,7 +117,9 @@ def Ge(self, lower: Union[bool, int, float]) -> None:
if self.vtype not in [bool, int, float]:
raise Exception(f"Ge is not valid constraint on {self.vtype}")
if invalid_vtype(self.vtype, lower):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {lower} is not of type {self.vtype}"
)
if self.space.empty():
return
elif self.space.discrete.initialized:
Expand All @@ -126,7 +136,9 @@ def Gt(self, lower: Union[bool, int, float]) -> None:
if self.vtype not in [bool, int, float]:
raise Exception(f"Gt is not valid constraint on {self.vtype}")
if invalid_vtype(self.vtype, lower):
raise TypeError("Variable type mismatch")
raise TypeError(
f"Variable type mismatch: {lower} is not of type {self.vtype}"
)
if self.space.empty():
return
elif self.space.discrete.initialized:
Expand Down
1 change: 1 addition & 0 deletions test/specdb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SpecDB test package
64 changes: 64 additions & 0 deletions test/specdb/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Optional

from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.utils.config import TensorConfig
from facto.specdb.db import SpecDictDB
from facto.utils.ops import get_op_overload


class BaseSpecDBTest(unittest.TestCase):
    """Base test class for validating all specs in SpecDB using gen_errors."""

    def _run_op(self, op_name: str, *, config: Optional[TensorConfig] = None):
        """
        Run a single op in SpecDB with a given TensorConfig.

        Calls ArgumentTupleGenerator.gen_errors with valid=True, out=False
        for one operation, inside a subTest so each op is reported separately.

        Args:
            op_name: Key into SpecDictDB identifying the operator spec.
            config: Optional TensorConfig forwarded to the generator.
        """
        print("Testing op: ", op_name)
        with self.subTest(op=op_name):
            try:
                # Get the spec and operation
                spec = SpecDictDB[op_name]
                op = get_op_overload(op_name)
                generator = ArgumentTupleGenerator(spec, config)
            except Exception as e:
                # If we can't resolve the operation or there's another issue,
                # fail this subtest with a descriptive message
                self.fail(f"Failed to test operation {op_name}: {e}")

            try:
                # gen_errors yields only argument tuples that misbehave,
                # so an empty result means the op passed.
                errors = list(
                    generator.gen_errors(op, valid=True, out=False, verbose=True)
                )
            except Exception as e:
                self.fail(f"Failed while testing operation {op_name}: {e}")

            if errors:
                self.fail(
                    f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
                )

    def _run_all_ops(self, *, config: Optional[TensorConfig] = None, skip_ops=None):
        """
        Run all ops in SpecDB with a given TensorConfig.

        Iterates through all operations in SpecDB and calls
        ArgumentTupleGenerator.gen_errors with valid=True, out=False for each
        operation. Each operation is tested as a subtest.

        Args:
            config: Optional TensorConfig forwarded to each generator.
            skip_ops: Optional collection of op names to exclude. Defaults to
                None (skip nothing); using None instead of [] avoids the
                shared mutable-default-argument pitfall.
        """
        # Normalize to a set: handles None and gives O(1) membership tests.
        skipped = set(skip_ops) if skip_ops else set()

        for op_name in SpecDictDB:
            if op_name in skipped:
                continue
            self._run_op(op_name, config=config)
64 changes: 0 additions & 64 deletions test/specdb/test_specdb.py

This file was deleted.

29 changes: 29 additions & 0 deletions test/specdb/test_specdb_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from base_test import BaseSpecDBTest


class TestSpecDBOperationsCPU(BaseSpecDBTest):
    """Test class for validating all specs in SpecDB using gen_errors on CPU."""

    def test_all_ops_cpu(self):
        """Run every SpecDB op with the default (CPU) config, minus known failures."""
        # NOTE(review): these ops are excluded — presumably their specs still
        # need calibration before they pass on CPU; confirm before re-enabling.
        excluded = (
            "_native_batch_norm_legit_no_training.default",
            "addmm.default",
            "arange.default",
            "arange.start_step",
            "constant_pad_nd.default",
            "split_with_sizes_copy.default",
        )
        self._run_all_ops(skip_ops=excluded)


if __name__ == "__main__":
    unittest.main()
59 changes: 59 additions & 0 deletions test/specdb/test_specdb_mps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from base_test import BaseSpecDBTest
from facto.inputgen.utils.config import TensorConfig


class TestSpecDBOperationsMPS(BaseSpecDBTest):
    """Test class for validating all specs in SpecDB using gen_errors on MPS."""

    def test_all_ops_mps(self):
        """Run every SpecDB op on the MPS device, minus known failures."""
        # Specs that need calibration (these do not pass on CPU either).
        needs_calibration = (
            "addmm.default",
            "arange.default",
            "arange.start_step",
            "constant_pad_nd.default",
            "split_with_sizes_copy.default",
        )
        # Known MPS backend bugs or gaps, with tracking issues where filed.
        mps_known_issues = (
            # https://github.com/pytorch/pytorch/issues/160208
            "add.Tensor",
            "add.Scalar",
            "rsub.Scalar",
            "sub.Tensor",
            "sub.Scalar",
            # crash: https://github.com/pytorch/pytorch/issues/154887
            "_native_batch_norm_legit_no_training.default",
            # not implemented
            "_pdist_forward.default",
            # impl: clamp tensor number of dims must not be greater than that
            # of the input tensor
            "clamp.Tensor",
            # crash: https://github.com/pytorch/pytorch/issues/154881
            "cumsum.default",
            # sparse_grad not supported in MPS yet
            "gather.default",
            # Dimension specified as -1 but tensor has no dimensions
            "index_select.default",
            # crash: https://github.com/pytorch/pytorch/issues/154882
            "max_pool2d_with_indices.default",
            # On-going MPSGraph topk issue when ndims() - axis > 4:
            # https://github.com/pytorch/pytorch/issues/154890
            "topk.default",
            # var_mps: reduction dim must be in the range of input shape
            "var.correction",
            "var.dim",
        )
        # Exclude float64 from generated inputs (not supported by the MPS
        # backend).
        mps_config = TensorConfig(device="mps", disallow_dtypes=[torch.float64])
        self._run_all_ops(
            config=mps_config,
            skip_ops=needs_calibration + mps_known_issues,
        )


if __name__ == "__main__":
    unittest.main()