|  | 
|  | 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | 
|  | 2 | + | 
|  | 3 | +# pyre-strict | 
|  | 4 | + | 
|  | 5 | +import copy | 
|  | 6 | +from typing import List, OrderedDict, Tuple | 
|  | 7 | + | 
|  | 8 | +import torch | 
|  | 9 | +from inputgen.argtuple.gen import ArgumentTupleGenerator | 
|  | 10 | +from inputgen.specs.model import ConstraintProducer as cp | 
|  | 11 | +from inputgen.utils.random_manager import random_manager | 
|  | 12 | +from inputgen.variable.type import ScalarDtype | 
|  | 13 | +from specdb.db import SpecDictDB | 
|  | 14 | + | 
# Seed the generator once at import time so every run produces identical
# cases — makes failures reproducible when bisecting. (1729 is an arbitrary
# fixed seed; any constant works.)
random_manager.seed(1729)
|  | 17 | + | 
|  | 18 | + | 
|  | 19 | +def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None: | 
|  | 20 | +    match op_name: | 
|  | 21 | +        case ( | 
|  | 22 | +            "sigmoid.default" | 
|  | 23 | +            | "_softmax.default" | 
|  | 24 | +            | "rsqrt.default" | 
|  | 25 | +            | "exp.default" | 
|  | 26 | +            | "mul.Tensor" | 
|  | 27 | +            | "div.Tensor" | 
|  | 28 | +        ): | 
|  | 29 | +            tensor_constraints.append( | 
|  | 30 | +                cp.Dtype.In(lambda deps: [torch.float]), | 
|  | 31 | +            ) | 
|  | 32 | +        case ( | 
|  | 33 | +            "add.Tensor" | 
|  | 34 | +            | "sub.Tensor" | 
|  | 35 | +            | "add.Scalar" | 
|  | 36 | +            | "sub.Scalar" | 
|  | 37 | +            | "mul.Scalar" | 
|  | 38 | +            | "div.Scalar" | 
|  | 39 | +        ): | 
|  | 40 | +            tensor_constraints.append( | 
|  | 41 | +                cp.Dtype.In(lambda deps: [torch.float, torch.int]), | 
|  | 42 | +            ) | 
|  | 43 | +        case _: | 
|  | 44 | +            tensor_constraints.append( | 
|  | 45 | +                cp.Dtype.In(lambda deps: [torch.float, torch.int]), | 
|  | 46 | +            ) | 
|  | 47 | +    tensor_constraints.extend( | 
|  | 48 | +        [ | 
|  | 49 | +            cp.Value.Ge(lambda deps, dtype, struct: -(2**8)), | 
|  | 50 | +            cp.Value.Le(lambda deps, dtype, struct: 2**8), | 
|  | 51 | +            cp.Rank.Ge(lambda deps: 1), | 
|  | 52 | +            cp.Rank.Le(lambda deps: 2**2), | 
|  | 53 | +            cp.Size.Ge(lambda deps, r, d: 1), | 
|  | 54 | +            cp.Size.Le(lambda deps, r, d: 2**2), | 
|  | 55 | +        ] | 
|  | 56 | +    ) | 
|  | 57 | + | 
|  | 58 | + | 
def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
    """Generate ``(posargs, inkwargs)`` test-case tuples for ``op_name`` via FACTO.

    Looks up the operator spec in ``SpecDictDB``, narrows its scalar and
    tensor input constraints to a small bounded domain, and materializes all
    argument tuples from ``ArgumentTupleGenerator``.

    NOTE(review): the declared return element types (``List[str]`` /
    ``OrderedDict[str, str]``) mirror the original annotation — confirm they
    match what ``ArgumentTupleGenerator.gen()`` actually yields.

    Args:
        op_name: Operator name key into ``SpecDictDB``.

    Returns:
        List of ``(posargs, inkwargs)`` pairs, one per generated case.
    """
    # Deep-copy the spec BEFORE mutating it: SpecDictDB entries are shared
    # module-level objects, and extending their constraints in place would
    # make constraints accumulate across repeated calls. (The original code
    # deep-copied only the iteration source while still mutating the shared
    # spec through spec.inspec[index].)
    spec = copy.deepcopy(SpecDictDB[op_name])

    for in_spec in spec.inspec:
        if in_spec.type.is_scalar():
            if in_spec.name != "alpha":
                # Generic scalar inputs: float/int, bounded magnitude.
                in_spec.constraints.extend(
                    [
                        cp.Dtype.In(lambda deps: [ScalarDtype.float, ScalarDtype.int]),
                        cp.Value.Ge(lambda deps, dtype: -(2**8)),
                        cp.Value.Le(lambda deps, dtype: 2**2),
                        cp.Size.Ge(lambda deps, r, d: 1),
                        cp.Size.Le(lambda deps, r, d: 2**2),
                    ]
                )
            else:
                # "alpha" must stay strictly positive and small.
                in_spec.constraints.extend(
                    [
                        cp.Value.Gt(lambda deps, dtype: 0),
                        cp.Value.Le(lambda deps, dtype: 2),
                    ]
                )
        elif in_spec.type.is_tensor():
            # Common per-op tensor constraints (dtype/value/rank/size).
            tensor_constraints: list[object] = []
            apply_tensor_contraints(op_name, tensor_constraints)
            in_spec.constraints.extend(tensor_constraints)

    return [
        (posargs, inkwargs)
        for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen()
    ]