|
18 | 18 |
|
19 | 19 | def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None: |
20 | 20 | match op_name: |
21 | | - case ( |
22 | | - "sigmoid.default" |
23 | | - | "_softmax.default" |
24 | | - | "rsqrt.default" |
25 | | - | "exp.default" |
26 | | - | "mul.Tensor" |
27 | | - | "div.Tensor" |
28 | | - ): |
| 21 | + case "sigmoid.default" | "rsqrt.default": |
29 | 22 | tensor_constraints.extend( |
30 | 23 | [ |
31 | 24 | cp.Dtype.In(lambda deps: [torch.float]), |
32 | | - cp.Size.Le(lambda deps, r, d: 2), |
33 | | - cp.Rank.Le(lambda deps: 2), |
| 25 | + cp.Rank.Le(lambda deps: 2**3), |
34 | 26 | ] |
35 | 27 | ) |
36 | | - case ( |
37 | | - "add.Tensor" |
38 | | - | "sub.Tensor" |
39 | | - | "add.Scalar" |
40 | | - | "sub.Scalar" |
41 | | - | "mul.Scalar" |
42 | | - | "div.Scalar" |
43 | | - ): |
| 28 | + case "exp.default": |
44 | 29 | tensor_constraints.extend( |
45 | 30 | [ |
46 | | - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
47 | | - cp.Size.Le(lambda deps, r, d: 2), |
48 | | - cp.Rank.Le(lambda deps: 2), |
49 | | - ] |
50 | | - ) |
51 | | - case "native_layer_norm.default": |
52 | | - tensor_constraints.extend( |
53 | | - [ |
54 | | - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
55 | | - cp.Size.Le(lambda deps, r, d: 2**4), |
56 | | - cp.Rank.Le(lambda deps: 2**4), |
| 31 | + cp.Rank.Le(lambda deps: 2**3), |
| 32 | + cp.Value.Ge(lambda deps, dtype, struct: -(2**2)), |
| 33 | + cp.Value.Le(lambda deps, dtype, struct: 2**2), |
57 | 34 | ] |
58 | 35 | ) |
59 | 36 | case _: |
60 | 37 | tensor_constraints.extend( |
61 | 38 | [ |
62 | | - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
63 | | - cp.Size.Le(lambda deps, r, d: 2), |
64 | | - cp.Rank.Le(lambda deps: 2), |
| 39 | + cp.Rank.Le(lambda deps: 2**2), |
65 | 40 | ] |
66 | 41 | ) |
67 | 42 | tensor_constraints.extend( |
68 | 43 | [ |
69 | | - cp.Value.Ge(lambda deps, dtype, struct: -(2**8)), |
70 | | - cp.Value.Le(lambda deps, dtype, struct: 2**8), |
| 44 | + cp.Dtype.In(lambda deps: [torch.int, torch.float]), |
| 45 | + cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]), |
| 46 | + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), |
| 47 | + cp.Value.Le(lambda deps, dtype, struct: 2**4), |
71 | 48 | cp.Rank.Ge(lambda deps: 1), |
72 | 49 | cp.Size.Ge(lambda deps, r, d: 1), |
| 50 | + cp.Size.Le(lambda deps, r, d: 2**9), |
73 | 51 | ] |
74 | 52 | ) |
75 | 53 |
|
|
0 commit comments