Skip to content

Commit 14bf790

Browse files
authored
migrate all test_aten_ops to facto
Differential Revision: D79121474 Pull Request resolved: #13483
1 parent e4ddf69 commit 14bf790

File tree

1 file changed: +138 additions, −17 deletions

backends/cadence/utils/facto_util.py

Lines changed: 138 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from functools import lru_cache
1111
from typing import List, OrderedDict, Tuple
1212

13+
import facto.specdb.function as fn
14+
1315
import torch
1416
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
1517
from facto.inputgen.specs.model import ConstraintProducer as cp
@@ -22,21 +24,28 @@
2224

2325
def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
2426
tensor_constraints = [
25-
cp.Dtype.In(lambda deps: [torch.int, torch.float]),
26-
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
27+
cp.Dtype.In(
28+
lambda deps: [
29+
torch.int8,
30+
torch.int16,
31+
torch.uint8,
32+
torch.uint16,
33+
torch.float32,
34+
]
35+
),
2736
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
2837
cp.Value.Le(lambda deps, dtype, struct: 2**4),
2938
cp.Rank.Ge(lambda deps: 1),
3039
cp.Size.Ge(lambda deps, r, d: 1),
3140
cp.Size.Le(lambda deps, r, d: 2**9),
41+
cp.Rank.Le(lambda deps: 2**3),
3242
]
3343

3444
match op_name:
3545
case "where.self":
3646
if index == 0: # condition
3747
tensor_constraints = [
3848
cp.Dtype.In(lambda deps: [torch.bool]),
39-
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
4049
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
4150
cp.Value.Le(lambda deps, dtype, struct: 2**4),
4251
cp.Rank.Ge(lambda deps: 1),
@@ -45,28 +54,43 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
4554
]
4655
else:
4756
tensor_constraints = [
48-
cp.Dtype.In(lambda deps: [torch.float, torch.int]),
49-
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
57+
cp.Dtype.In(
58+
lambda deps: [
59+
torch.int8,
60+
torch.int16,
61+
torch.uint8,
62+
torch.uint16,
63+
torch.float32,
64+
]
65+
),
5066
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
5167
cp.Value.Le(lambda deps, dtype, struct: 2**4),
5268
cp.Rank.Ge(lambda deps: 1),
5369
cp.Size.Ge(lambda deps, r, d: 1),
5470
cp.Size.Le(lambda deps, r, d: 2**9),
5571
]
72+
case "embedding.default":
73+
tensor_constraints = [
74+
cp.Dtype.In(lambda deps: [torch.float, torch.int]),
75+
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
76+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
77+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
78+
cp.Rank.Ge(lambda deps: 1),
79+
cp.Size.Ge(lambda deps, r, d: 1),
80+
cp.Size.Le(lambda deps, r, d: 2**9),
81+
]
5682
case "sigmoid.default":
5783
tensor_constraints.extend(
5884
[
59-
cp.Dtype.In(lambda deps: [torch.float]),
60-
cp.Rank.Le(lambda deps: 2**2),
85+
cp.Dtype.In(lambda deps: [torch.float32]),
6186
cp.Value.Ge(lambda deps, dtype, struct: -2),
6287
cp.Value.Le(lambda deps, dtype, struct: 2),
6388
]
6489
)
6590
case "rsqrt.default":
6691
tensor_constraints.extend(
6792
[
68-
cp.Dtype.In(lambda deps: [torch.float]),
69-
cp.Rank.Le(lambda deps: 2**2),
93+
cp.Dtype.In(lambda deps: [torch.float32]),
7094
cp.Value.Gt(
7195
lambda deps, dtype, struct: 0
7296
), # only generate real numbers
@@ -76,14 +100,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
76100
case "mean.dim":
77101
tensor_constraints.extend(
78102
[
79-
cp.Dtype.In(lambda deps: [torch.float]),
80-
cp.Rank.Le(lambda deps: 2**2),
103+
cp.Dtype.In(lambda deps: [torch.float32]),
81104
]
82105
)
83106
case "exp.default":
84107
tensor_constraints.extend(
85108
[
86-
cp.Rank.Le(lambda deps: 2**3),
87109
cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
88110
cp.Value.Le(lambda deps, dtype, struct: 2**2),
89111
]
@@ -96,26 +118,96 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
96118
cp.Value.Le(lambda deps, dtype, struct: 2),
97119
]
98120
)
99-
case _:
121+
case "constant_pad_nd.default":
100122
tensor_constraints.extend(
101123
[
102-
cp.Rank.Le(lambda deps: 2**2),
124+
cp.Dtype.In(lambda deps: [torch.float32]),
125+
cp.Size.Le(lambda deps, r, d: 2**2),
126+
]
127+
)
128+
case "avg_pool2d.default":
129+
tensor_constraints.extend(
130+
[
131+
cp.Rank.Eq(lambda deps: 4),
132+
]
133+
)
134+
case "bmm.default" | "addmm.default" | "mm.default":
135+
tensor_constraints.extend(
136+
[
137+
cp.Dtype.Eq(lambda deps: torch.float),
138+
cp.Size.Le(lambda deps, r, d: 2**2),
139+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
103140
]
104141
)
142+
case "div.Tensor":
143+
tensor_constraints.extend(
144+
[
145+
cp.Value.Ne(lambda deps, dtype, struct: 0),
146+
]
147+
)
148+
case "div.Tensor_mode" | "minimum.default":
149+
if index == 0:
150+
tensor_constraints = [
151+
cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
152+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
153+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
154+
cp.Rank.Ge(lambda deps: 1),
155+
cp.Size.Ge(lambda deps, r, d: 1),
156+
cp.Size.Le(lambda deps, r, d: 2**2),
157+
]
158+
else:
159+
tensor_constraints = [
160+
cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
161+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
162+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
163+
cp.Rank.Ge(lambda deps: 1),
164+
cp.Rank.Eq(lambda deps: deps[0].dim()),
165+
cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)),
166+
]
167+
case "_native_batch_norm_legit_no_training.default":
168+
tensor_constraints.extend(
169+
[
170+
cp.Rank.Le(lambda deps: 3),
171+
],
172+
)
173+
case "reciprocal.default":
174+
tensor_constraints = [
175+
cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
176+
cp.Value.Le(lambda deps, dtype, struct: 2**2),
177+
cp.Size.Le(lambda deps, r, d: 2**3),
178+
]
179+
case "_softmax.default":
180+
tensor_constraints.extend(
181+
[
182+
cp.Dtype.Eq(lambda deps: torch.float32),
183+
cp.Size.Le(lambda deps, r, d: 2**2),
184+
]
185+
)
186+
case _:
187+
pass
105188
return tensor_constraints
106189

107190

108191
def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
109192
match op_name:
110-
case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar":
193+
case (
194+
"add.Scalar"
195+
| "sub.Scalar"
196+
| "mul.Scalar"
197+
| "div.Scalar"
198+
| "constant_pad_nd.default"
199+
):
200+
return [ScalarDtype.int]
201+
case "full.default":
111202
return [ScalarDtype.int]
112-
113203
case _:
114204
return [ScalarDtype.float, ScalarDtype.int]
115205

116206

117207
@lru_cache(maxsize=None)
118-
def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
208+
def facto_testcase_gen( # noqa: C901
209+
op_name: str,
210+
) -> List[Tuple[List[str], OrderedDict[str, str]]]:
119211
# minimal example to test add.Tensor using FACTO
120212
spec = SpecDictDB[op_name]
121213

@@ -149,6 +241,12 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
149241
cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
150242
]
151243
)
244+
if in_spec.name == "dtype": # full.default
245+
spec.inspec[index].constraints.extend(
246+
[
247+
cp.Dtype.In(lambda deps: [torch.long, torch.float]),
248+
]
249+
)
152250
elif in_spec.type.is_tensor():
153251
spec.inspec[index].constraints.extend(
154252
apply_tensor_contraints(op_name, index)
@@ -166,6 +264,29 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
166264
cp.Dtype.In(lambda deps: [torch.bool]),
167265
]
168266
)
267+
elif in_spec.type.is_length_list():
268+
spec.inspec[index].constraints.extend(
269+
[
270+
cp.Value.Ge(lambda deps, dtype, struct: 0),
271+
]
272+
)
273+
if op_name == "avg_pool2d.default":
274+
spec.inspec[index].constraints.extend(
275+
[
276+
cp.Length.Eq(lambda deps: 2),
277+
]
278+
)
279+
elif in_spec.type.is_shape():
280+
spec.inspec[index].constraints.extend(
281+
[
282+
cp.Rank.Ge(lambda deps: 1),
283+
cp.Rank.Le(lambda deps: 2**2),
284+
cp.Value.Gt(lambda deps, dtype, struct: 0),
285+
cp.Value.Le(lambda deps, dtype, struct: 2**2),
286+
cp.Size.Ge(lambda deps, r, d: 1),
287+
cp.Size.Le(lambda deps, r, d: 2**2),
288+
]
289+
)
169290

170291
return [
171292
(posargs, inkwargs)

0 commit comments

Comments (0)