From 83a9bc0bd36b4d3a1f9ab3e1b76a91f8effbc07b Mon Sep 17 00:00:00 2001
From: Zonglin Peng
Date: Thu, 29 May 2025 18:50:00 -0700
Subject: [PATCH] fix where.self out constraint to make a, b numerical (#11240)

Summary: condition should be bool; self and other should be numerical.

Differential Revision: D75644590
---
 backends/cadence/utils/facto_util.py | 54 ++++++++++++++++------------
 1 file changed, 32 insertions(+), 22 deletions(-)

diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py
index 8cd57059244..b896f8a8e89 100644
--- a/backends/cadence/utils/facto_util.py
+++ b/backends/cadence/utils/facto_util.py
@@ -20,8 +20,8 @@
 MAX_CASES = 50
 
 
-def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
-    additional_tensor_constraints = [
+def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
+    tensor_constraints = [
         cp.Dtype.In(lambda deps: [torch.int, torch.float]),
         cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
         cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
@@ -33,17 +33,28 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
 
     match op_name:
         case "where.self":
-            additional_tensor_constraints = [
-                cp.Dtype.In(lambda deps: [torch.float, torch.int, torch.bool]),
-                cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
-                cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
-                cp.Value.Le(lambda deps, dtype, struct: 2**4),
-                cp.Rank.Ge(lambda deps: 1),
-                cp.Size.Ge(lambda deps, r, d: 1),
-                cp.Size.Le(lambda deps, r, d: 2**9),
-            ]
+            if index == 0:  # condition
+                tensor_constraints = [
+                    cp.Dtype.In(lambda deps: [torch.bool]),
+                    cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
+                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                    cp.Rank.Ge(lambda deps: 1),
+                    cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.Le(lambda deps, r, d: 2**9),
+                ]
+            else:
+                tensor_constraints = [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int]),
+                    cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
+                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                    cp.Rank.Ge(lambda deps: 1),
+                    cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.Le(lambda deps, r, d: 2**9),
+                ]
         case "sigmoid.default":
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
                     cp.Rank.Le(lambda deps: 2**2),
@@ -52,7 +63,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
                 ]
             )
         case "rsqrt.default":
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
                     cp.Rank.Le(lambda deps: 2**2),
@@ -63,14 +74,14 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
                 ]
             )
         case "mean.dim":
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
                     cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
         case "exp.default":
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2**3),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
@@ -78,7 +89,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
                 ]
             )
         case "slice_copy.Tensor":
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2),
                     cp.Value.Ge(lambda deps, dtype, struct: 1),
@@ -86,12 +97,12 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
                 ]
             )
         case _:
-            additional_tensor_constraints.extend(
+            tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
-    tensor_constraints.extend(additional_tensor_constraints)
+    return tensor_constraints
 
 
 def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
@@ -107,9 +118,6 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
 def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
     # minimal example to test add.Tensor using FACTO
     spec = SpecDictDB[op_name]
-    tensor_constraints = []
-    # common tensor constraints
-    apply_tensor_contraints(op_name, tensor_constraints)
 
     for index, in_spec in enumerate(copy.deepcopy(spec.inspec)):
         if in_spec.type.is_scalar():
@@ -142,7 +150,9 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
             ]
         )
         elif in_spec.type.is_tensor():
-            spec.inspec[index].constraints.extend(tensor_constraints)
+            spec.inspec[index].constraints.extend(
+                apply_tensor_contraints(op_name, index)
+            )
         elif in_spec.type.is_dim_list():
             spec.inspec[index].constraints.extend(
                 [