diff --git a/examples/cadence/operators/facto_util.py b/examples/cadence/operators/facto_util.py
index 304b1c7e726..5e6a58ce9f4 100644
--- a/examples/cadence/operators/facto_util.py
+++ b/examples/cadence/operators/facto_util.py
@@ -22,7 +22,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
-                    cp.Rank.Le(lambda deps: 2**3),
+                    cp.Rank.Le(lambda deps: 2**2),
+                    cp.Value.Ge(lambda deps, dtype, struct: -2),
+                    cp.Value.Le(lambda deps, dtype, struct: 2),
+                ]
+            )
+        case "mean.dim":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float]),
+                    cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
         case "exp.default":
@@ -86,8 +95,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
                     cp.Value.Le(lambda deps, dtype: 2),
                 ]
             )
+        elif in_spec.type.is_scalar_type():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
+                ]
+            )
         elif in_spec.type.is_tensor():
             spec.inspec[index].constraints.extend(tensor_constraints)
+        elif in_spec.type.is_dim_list():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Length.Ge(lambda deps: 1),
+                    cp.Optional.Eq(lambda deps: False),
+                ]
+            )
+        elif in_spec.type.is_bool():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.bool]),
+                ]
+            )
 
     return [
         (posargs, inkwargs)
diff --git a/examples/cadence/operators/test_g3_ops.py b/examples/cadence/operators/test_g3_ops.py
index 158e13d389f..58433cc739e 100644
--- a/examples/cadence/operators/test_g3_ops.py
+++ b/examples/cadence/operators/test_g3_ops.py
@@ -259,6 +259,35 @@ def test_g3__softmax_out(
         self.run_and_verify(model, (inputs,))
 
+    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
+    @parameterized.expand([*facto_util.facto_testcase_gen("mean.dim")])
+    def test_g3_mean_dim_out(
+        self,
+        posargs: List[int],
+        inkwargs: OrderedDict[str, str],
+    ) -> None:
+        class Meandim(nn.Module):
+            def forward(
+                self,
+                x: torch.Tensor,
+                dim_list: Tuple[int],
+                keepdim: bool,
+                dtype: torch.dtype = torch.float32,
+            ) -> torch.Tensor:
+                return torch.ops.aten.mean.dim(
+                    x,
+                    dim_list,
+                    keepdim,
+                    dtype=dtype,
+                )
+
+        model = Meandim()
+
+        self.run_and_verify(
+            model,
+            inputs=tuple(posargs),
+        )
+
 
 
 if __name__ == "__main__":
     unittest.main()