diff --git a/examples/cadence/operators/facto_util.py b/examples/cadence/operators/facto_util.py
index e9b16f8bf6f..e708796c7b7 100644
--- a/examples/cadence/operators/facto_util.py
+++ b/examples/cadence/operators/facto_util.py
@@ -26,8 +26,12 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
             | "mul.Tensor"
             | "div.Tensor"
         ):
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
             )
         case (
             "add.Tensor"
@@ -37,35 +41,60 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
             | "mul.Scalar"
             | "div.Scalar"
         ):
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float, torch.int]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
+            )
+        case "native_layer_norm.default":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2**4),
+                    cp.Rank.Le(lambda deps: 2**4),
+                ]
             )
         case _:
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float, torch.int]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
             )
     tensor_constraints.extend(
         [
             cp.Value.Ge(lambda deps, dtype, struct: -(2**8)),
             cp.Value.Le(lambda deps, dtype, struct: 2**8),
             cp.Rank.Ge(lambda deps: 1),
-            cp.Rank.Le(lambda deps: 2**2),
             cp.Size.Ge(lambda deps, r, d: 1),
-            cp.Size.Le(lambda deps, r, d: 2**2),
         ]
     )
 
 
+def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
+    match op_name:
+        case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar":
+            return [ScalarDtype.int]
+        case _:
+            return [ScalarDtype.float, ScalarDtype.int]
+
+
 def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
     # minimal example to test add.Tensor using FACTO
     spec = SpecDictDB[op_name]
+    tensor_constraints = []
+    # common tensor constraints
+    apply_tensor_contraints(op_name, tensor_constraints)
     for index, in_spec in enumerate(copy.deepcopy(spec.inspec)):
         if in_spec.type.is_scalar():
             if in_spec.name != "alpha":
                 spec.inspec[index].constraints.extend(
                     [
-                        cp.Dtype.In(lambda deps: [ScalarDtype.float, ScalarDtype.int]),
+                        cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
                         cp.Value.Ge(lambda deps, dtype: -(2**8)),
                         cp.Value.Le(lambda deps, dtype: 2**2),
                         cp.Size.Ge(lambda deps, r, d: 1),
@@ -80,9 +109,6 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
                     ]
                 )
         elif in_spec.type.is_tensor():
-            tensor_constraints = []
-            # common tensor constraints
-            apply_tensor_contraints(op_name, tensor_constraints)
             spec.inspec[index].constraints.extend(tensor_constraints)
 
     return [
diff --git a/examples/cadence/operators/targets.bzl b/examples/cadence/operators/targets.bzl
index a646f0076b4..32dc9061b51 100644
--- a/examples/cadence/operators/targets.bzl
+++ b/examples/cadence/operators/targets.bzl
@@ -9,6 +9,7 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 
 TESTS_LIST = [
     "add_op",
+    "g3_ops",
     "quantized_conv1d_op",
     "quantized_linear_op",
 ]
@@ -46,5 +47,6 @@ def _define_test_target(test_name):
             "fbcode//executorch/backends/cadence/aot:ops_registrations",
             "fbcode//executorch/backends/cadence/aot:export_example",
             "fbcode//executorch/backends/cadence/aot:compiler",
+            "fbcode//executorch/examples/cadence/operators:facto_util",
"fbcode//executorch/examples/cadence/operators:facto_util", ], ) diff --git a/examples/cadence/operators/test_g3_ops.py b/examples/cadence/operators/test_g3_ops.py new file mode 100644 index 00000000000..158e13d389f --- /dev/null +++ b/examples/cadence/operators/test_g3_ops.py @@ -0,0 +1,264 @@ +import unittest +from typing import Any, cast, List, OrderedDict, Tuple + +from executorch.examples.cadence.operators import facto_util + +from parameterized import parameterized + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch +import torch.nn as nn +from executorch.backends.cadence.aot.export_example import export_model + + +class ATenOpTestCases(unittest.TestCase): + def run_and_verify(self, model: nn.Module, inputs: Tuple[Any, ...]) -> None: + model.eval() + export_model( + model, inputs, file_name=self._testMethodName, run_and_compare=False + ) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("add.Tensor")]) + @torch.no_grad() + def test_g3_add_tensor_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class AddTensor(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.add(x, y, alpha=self.alpha) + + model = AddTensor(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("add.Scalar")]) + @torch.no_grad() + def test_aten_add_Scalar_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class AddScalar(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: float): + return torch.add(x, y, alpha=self.alpha) + + inputs = posargs[:-1] # posargs = [x_tensor, y_scalar, alpha_scalar] + alpha = posargs[-1] + model = AddScalar(alpha) + + self.run_and_verify(model, tuple(inputs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("sub.Tensor")]) + @torch.no_grad() + def test_g3_sub_tensor_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class SubTensor(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.sub(x, y, alpha=self.alpha) + + model = SubTensor(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("sub.Scalar")]) + @torch.no_grad() + def test_g3_sub_scalar_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + # Tensor-Scalar subtraction + class SubScalar(torch.nn.Module): + def __init__(self, other): + super().__init__() + self.other = other + + def forward(self, x): + return torch.ops.aten.sub.Scalar(x, self.other) + + inputs = posargs[0] # posargs = [x_tensor, y_scalar, alpha_scalar] + model = SubScalar(posargs[1]) + + self.run_and_verify(model, (inputs,)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. 
+ @parameterized.expand([*facto_util.facto_testcase_gen("div.Tensor")]) + @torch.no_grad() + def test_g3_div_tensor_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class DivTensor(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.div(x, y + 1) + + model = DivTensor(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("div.Scalar")]) + @torch.no_grad() + def test_g3_div_scalar_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class DivScalar(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.div(x, y + 1) + + model = DivScalar(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("exp.default")]) + @torch.no_grad() + def test_g3_exp_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class Exp(nn.Module): + def forward(self, x: torch.Tensor): + return torch.exp(x) + + model = Exp(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("mul.Tensor")]) + @torch.no_grad() + def test_g3_mul_tensor_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class MulTensor(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x * y + + model = MulTensor(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("mul.Scalar")]) + @torch.no_grad() + def test_g3_mul_scalar_out( + self, + posargs: List[str], + inkwargs: OrderedDict[str, str], + ) -> None: + class MulScalar(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x * y + + model = MulScalar(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("native_layer_norm.default")]) + @torch.no_grad() + def test_g3_native_layer_norm_out( + self, + posargs: List[int], + inkwargs: OrderedDict[str, str], + ) -> None: + inputs, normalized_shape, weight, bias, _ = posargs + model = nn.LayerNorm(normalized_shape, eps=1e-5) + if weight is not None: + weight = cast(torch.Tensor, weight) + model.weight = nn.Parameter(torch.rand_like(weight)) + if bias is not None: + bias = cast(torch.Tensor, bias) + model.bias = nn.Parameter(torch.rand_like(bias)) + + self.run_and_verify(model, (inputs,)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("neg.default")]) + @torch.no_grad() + def test_g3_neg_out( + self, + posargs: List[int], + inkwargs: OrderedDict[str, str], + ) -> None: + class Neg(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.neg(x) + + model = Neg(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. 
+ @parameterized.expand([*facto_util.facto_testcase_gen("rsqrt.default")]) + @torch.no_grad() + def test_g3_rsqrt_out( + self, + posargs: List[int], + inkwargs: OrderedDict[str, str], + ) -> None: + class Rsqrt(nn.Module): + def forward(self, x: torch.Tensor): + return torch.ops.aten.rsqrt(x) + + model = Rsqrt(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("sigmoid.default")]) + @torch.no_grad() + def test_g3_sigmoid_out( + self, + posargs: List[int], + inkwargs: OrderedDict[str, str], + ) -> None: + model = nn.Sigmoid(**inkwargs) + + self.run_and_verify(model, tuple(posargs)) + + # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`. + @parameterized.expand([*facto_util.facto_testcase_gen("_softmax.default")]) + @torch.no_grad() + def test_g3__softmax_out( + self, + posargs: List[int], + inkwargs: OrderedDict[str, str], + ) -> None: + inputs, _, _ = posargs + model = nn.Softmax(dim=-1) + + self.run_and_verify(model, (inputs,)) + + +if __name__ == "__main__": + unittest.main()