Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion examples/cadence/operators/facto_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float]),
cp.Rank.Le(lambda deps: 2**3),
cp.Rank.Le(lambda deps: 2**2),
cp.Value.Ge(lambda deps, dtype, struct: -2),
cp.Value.Le(lambda deps, dtype, struct: 2),
]
)
case "mean.dim":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float]),
cp.Rank.Le(lambda deps: 2**2),
]
)
case "exp.default":
Expand Down Expand Up @@ -86,8 +95,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
cp.Value.Le(lambda deps, dtype: 2),
]
)
elif in_spec.type.is_scalar_type():
spec.inspec[index].constraints.extend(
[
cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
]
)
elif in_spec.type.is_tensor():
spec.inspec[index].constraints.extend(tensor_constraints)
elif in_spec.type.is_dim_list():
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Optional.Eq(lambda deps: False),
]
)
elif in_spec.type.is_bool():
spec.inspec[index].constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.bool]),
]
)

return [
(posargs, inkwargs)
Expand Down
29 changes: 29 additions & 0 deletions examples/cadence/operators/test_g3_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,35 @@ def test_g3__softmax_out(

self.run_and_verify(model, (inputs,))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("mean.dim")])
def test_g3_mean_dim_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
class Meandim(nn.Module):
def forward(
self,
x: torch.Tensor,
dim_list: Tuple[int],
keepdim: bool,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
return torch.ops.aten.mean.dim(
x,
dim_list,
keepdim,
dtype=dtype,
)

model = Meandim()

self.run_and_verify(
model,
inputs=tuple(posargs),
)


if __name__ == "__main__":
unittest.main()
Loading