diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index 9b50b469627..52b64dc1581 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -43,7 +43,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N cp.Size.Ge(lambda deps, r, d: 1), cp.Size.Le(lambda deps, r, d: 2**9), ] - case "sigmoid.default" | "rsqrt.default": + case "sigmoid.default": additional_tensor_constraints.extend( [ cp.Dtype.In(lambda deps: [torch.float]), @@ -52,6 +52,17 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N cp.Value.Le(lambda deps, dtype, struct: 2), ] ) + case "rsqrt.default": + additional_tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float]), + cp.Rank.Le(lambda deps: 2**2), + cp.Value.Gt( + lambda deps, dtype, struct: 0 + ), # only generate real numbers + cp.Value.Le(lambda deps, dtype, struct: 2**2), + ] + ) case "mean.dim": additional_tensor_constraints.extend( [