190 changes: 147 additions & 43 deletions backends/cadence/utils/facto_util.py
@@ -15,6 +15,7 @@
import torch
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import ConstraintProducer as cp
from facto.inputgen.utils.random_manager import seeded_random_manager as rm
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB

@@ -26,6 +27,33 @@
_shape_cache: dict[str, list[int]] = {}


def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]:
"""
Generate valid permutations using only positive dimension indices.
This is required for Cadence/Xtensa kernels that don't support negative indexing.

Args:
tensor: Input tensor to generate permutations for
length: Number of dimension indices to draw (values greater than tensor.dim() yield an empty set)

Returns:
Set of valid permutation tuples containing only non-negative indices in [0, rank - 1]
"""
if length > tensor.dim():
return set()

n = tensor.dim()
pool = list(range(n))

# Generate multiple valid permutations (only positive indices)
permutations: set[tuple[int, ...]] = set()
for _ in range(3): # Draw up to 3 distinct permutations for diversity (the set drops duplicates)
perm = tuple(rm.get_random().sample(pool, length))
permutations.add(perm)

return permutations
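# Editor's sketch (illustrative, not part of the original diff): for a rank-3
# tensor this helper yields up to three permutations drawn from the seeded RNG,
# with every index lying in [0, tensor.dim() - 1]:
#
#   >>> t = torch.empty(2, 3, 4)
#   >>> _positive_valid_dim_list(t, 3)
#   {(2, 0, 1), (0, 2, 1), (1, 0, 2)}   # exact tuples depend on the seed
#
# A length greater than t.dim() returns an empty set, so the cp.Value.Gen
# constraint that consumes it produces no candidates rather than an invalid
# permutation.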


def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
# Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
import random
@@ -161,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
if index == 0: # condition
tensor_constraints = [
cp.Dtype.In(lambda deps: [torch.bool]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Value.Ge(lambda deps, dtype, struct: 0),
cp.Value.Le(lambda deps, dtype, struct: 1),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
max_size_constraint,
]
elif index == 1: # input tensor(a)
tensor_constraints = [
cp.Dtype.In(
lambda deps: [
torch.int8,
torch.int16,
torch.uint8,
torch.uint16,
torch.int32,
torch.float32,
]
),
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.In(
lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
),
max_size_constraint,
]
else: # input tensor(b)
tensor_constraints = [
cp.Dtype.In(
lambda deps: [
torch.int8,
torch.int16,
torch.uint8,
torch.uint16,
torch.int32,
torch.float32,
]
),
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Dtype.Eq(lambda deps: deps[1].dtype),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.In(
lambda deps, r, d: fn.broadcast_with(
fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
)
),
max_size_constraint,
]
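# Editor's note (illustrative assumption): fn.broadcast_with is taken to
# enumerate sizes for dimension d of a rank-r candidate that broadcast against
# the given shape under standard PyTorch rules. E.g. with
# deps[0].shape == (2, 1, 3) and r == 3, dimension 0 may be 1 or 2,
# dimension 1 may be any permitted size, and dimension 2 may be 1 or 3.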
case "embedding.default":
@@ -248,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
# Bound values to a small finite range; avoids NaN/Inf inputs that expose clamp NaN-handling bugs
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
]
)
case "rsqrt.default":
@@ -323,12 +344,15 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
]
)
case "constant_pad_nd.default":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
tensor_constraints = [
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Rank.Le(lambda deps: 2), # Reduced from 3 to 2 (max 2D tensors)
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.Le(lambda deps, r, d: 3), # Max dimension size of 3
]
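# Editor's note: with rank capped at 2 and every dimension size capped at 3,
# a generated input holds at most 3 * 3 = 9 elements, keeping padded outputs small.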
case "avg_pool2d.default":
tensor_constraints.extend(
[
@@ -344,14 +368,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
]
)
case "div.Tensor":
tensor_constraints.extend(
[
cp.Value.Ne(lambda deps, dtype, struct: 0),
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
if index == 1: # Only apply zero-prevention to divisor
tensor_constraints.extend(
[
cp.Value.Ne(
lambda deps, dtype, struct: 0
), # Prevent division by zero
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
else:
tensor_constraints.extend(
[
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
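# Editor's note (illustrative): only the divisor (index == 1) carries
# Value.Ne(0); dividends of zero are still generated (e.g. 0 / b with b != 0),
# and both operands share the Value.Le(2**3) cap.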
case "pow.Tensor_Scalar":
tensor_constraints.extend(
[
@@ -373,6 +408,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Value.Ne(
lambda deps, dtype, struct: 0
), # Prevent division by zero
cp.Rank.Ge(lambda deps: 1),
cp.Rank.Eq(lambda deps: deps[0].dim()),
cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)),
@@ -389,13 +427,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
cp.Value.Le(lambda deps, dtype, struct: 2**2),
cp.Size.Le(lambda deps, r, d: 2**3),
]
case "leaky_relu.default":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32]),
]
)
case "_softmax.default":
tensor_constraints.extend(
[
cp.Dtype.Eq(lambda deps: torch.float32),
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
case "flip.default":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32]),
]
)
case _:
pass
return tensor_constraints
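# Editor's sketch (illustrative; assumes the FACTO APIs imported above behave
# as shown, mirroring how facto_testcase_gen wires things together below):
#
#   spec = SpecDictDB["where.self"]
#   spec.inspec[1].constraints.extend(apply_tensor_contraints("where.self", 1))
#   for posargs, inkwargs, outargs in ArgumentTupleGenerator(spec).gen():
#       ...  # run the kernel under test against the generated inputs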
@@ -409,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
| "mul.Scalar"
| "div.Scalar"
| "constant_pad_nd.default"
| "clamp.default"
):
return [ScalarDtype.int]
case "full.default":
@@ -436,11 +487,44 @@ def facto_testcase_gen( # noqa: C901
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
if in_spec.name == "max_val": # hardtanh
# Special handling for clamp.default: min and max are never None and keep a gap of at least 2 (max >= min + 2)
if op_name == "clamp.default":
if in_spec.name == "min":
# min must always be provided (not None) and bounded, leaving room for max
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(lambda deps, dtype: -(2**4)),
cp.Value.Le(
lambda deps, dtype: 2**4 - 2
), # Leave room for max (at least 2 units)
]
)
elif in_spec.name == "max":
# max must always be provided (not None), satisfy max >= min + 2, and stay bounded
spec.inspec[index].deps = [0, 1] # depends on the input tensor and min
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(
lambda deps, dtype: deps[1] + 2
), # max >= min + 2 (sufficient gap)
cp.Value.Le(lambda deps, dtype: 2**4),
]
)
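# Editor's note (worked example): min is drawn from [-16, 14] and max from
# [min + 2, 16], so min == -3 forces max into [-1, 16]; the generator can
# never emit min == max or a None bound for clamp.default.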
elif in_spec.name == "max_val": # hardtanh
spec.inspec[index].deps = [0, 1]
spec.inspec[index].constraints.extend(
[cp.Value.Ge(lambda deps, _: deps[1])]
)
elif in_spec.name == "negative_slope" and op_name == "leaky_relu.default":
# For leaky_relu, negative_slope should lie in the typical range (0, 1]
spec.inspec[index].constraints.extend(
[
cp.Value.Gt(lambda deps, dtype: 0),
cp.Value.Le(lambda deps, dtype: 1.0),
]
)
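# Editor's note: this keeps negative_slope strictly positive and at most 1.0,
# so the common default of 0.01 falls inside the generated range (0, 1].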
else:
spec.inspec[index].constraints.extend(
[
@@ -465,12 +549,32 @@ def facto_testcase_gen( # noqa: C901
apply_tensor_contraints(op_name, index)
)
elif in_spec.type.is_dim_list():
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Optional.Eq(lambda deps: False),
]
)
# Special handling for permute_copy.default to ensure a valid permutation
if op_name == "permute_copy.default":
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Length.Eq(
lambda deps: deps[0].dim()
), # Must be a complete permutation
cp.Optional.Eq(lambda deps: False),
# Generate valid permutations using only positive indices
# Cadence/Xtensa hardware kernels do not support negative dimension indices
cp.Value.Gen(
lambda deps, length: (
_positive_valid_dim_list(deps[0], length),
fn.invalid_dim_list(deps[0], length),
)
),
]
)
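# Editor's note (illustrative): for a rank-3 input, cp.Value.Gen pairs valid
# draws such as (2, 0, 1) from _positive_valid_dim_list with counter-examples
# from fn.invalid_dim_list (assumed to cover negative-index variants such as
# (-1, 0, 1)), so both passing and failing dim lists of length deps[0].dim()
# are exercised.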
else:
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Optional.Eq(lambda deps: False),
]
)
elif in_spec.type.is_bool():
spec.inspec[index].constraints.extend(
[