7
7
import argparse
8
8
import logging
9
9
import sys
10
- from typing import Any , List , OrderedDict , Tuple
10
+ from typing import Any , List , Optional , OrderedDict , Tuple
11
11
12
12
import torch
13
- from executorch .exir .dialects .edge .op .api import get_callable , to_variant
14
13
from facto .inputgen .argtuple .engine import MetaArgTupleEngine
15
14
from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
16
15
from facto .inputgen .argument .engine import MetaArg
17
16
from facto .inputgen .specs .model import Spec
17
+ from facto .inputgen .utils .config import TensorConfig
18
18
from facto .specdb .db import SpecDictDB
19
+ from facto .utils .ops import get_op_overload
19
20
from torch ._ops import OpOverload
20
- from torchgen .model import SchemaKind
21
21
22
22
logging .basicConfig (stream = sys .stderr , level = logging .WARNING )
23
23
@@ -29,32 +29,27 @@ def smt(meta_tuple):
29
29
class SpecRunner :
30
30
def __init__ (
31
31
self ,
32
+ op : OpOverload ,
32
33
spec : Spec ,
33
34
* ,
34
35
valid : bool = True ,
35
36
out : bool = False ,
36
37
devices : Tuple [str ] = ("cpu" ,),
38
+ config : Optional [TensorConfig ] = None ,
37
39
):
38
40
self .spec = spec
39
- self .generator = ArgumentTupleGenerator (self .spec )
41
+ self .config = config
42
+ self .generator = ArgumentTupleGenerator (self .spec , config = config )
40
43
self .valid = valid
41
44
self .out = out
42
- self .op_name = spec . op
43
- self .op = self . get_callable_op ()
45
+ self .op_name = op . __name__
46
+ self .op = op
44
47
self .results = {}
45
48
self .devices = devices
46
49
self .results = {}
47
50
for device in self .devices :
48
51
self .results [device ] = {}
49
52
50
- def get_callable_op (self ):
51
- name = self .spec .op
52
- op : OpOverload = get_callable (name )
53
- if self .out :
54
- # Get the out variant op
55
- op : OpOverload = to_variant (op , SchemaKind .out )
56
- return op
57
-
58
53
def report_device (self , device ):
59
54
print (f"Device: { device } \n " )
60
55
failures = []
@@ -78,7 +73,7 @@ def report_inconsistencies(self):
78
73
for device in self .devices [1 :]:
79
74
res ^= self .results [device ][meta_tuple ]
80
75
if not res :
81
- inconsistencies .append (meta_tuple )
76
+ inconsistencies .add (meta_tuple )
82
77
if len (inconsistencies ) > 0 :
83
78
print ("INCONSISTENCIES\n " )
84
79
for meta_tuple in inconsistencies :
@@ -98,26 +93,26 @@ def run(self):
98
93
def move_to_device (
99
94
self ,
100
95
device : str ,
101
- cpu_posargs : List [Any ],
102
- cpu_inkwargs : OrderedDict [str , Any ],
103
- cpu_outargs : OrderedDict [str , Any ],
96
+ src_posargs : List [Any ],
97
+ src_inkwargs : OrderedDict [str , Any ],
98
+ src_outargs : OrderedDict [str , Any ],
104
99
):
105
- if device == "cpu" :
106
- return cpu_posargs , cpu_inkwargs , cpu_outargs
100
+ if device == ( "cpu" if self . config is None else self . config . device ) :
101
+ return src_posargs , src_inkwargs , src_outargs
107
102
posargs = []
108
103
inkwargs = OrderedDict ()
109
104
outargs = OrderedDict ()
110
- for arg in cpu_posargs :
105
+ for arg in src_posargs :
111
106
new = arg
112
107
if isinstance (arg , torch .Tensor ):
113
108
new = arg .to (device = device )
114
109
posargs .append (new )
115
- for k , v in cpu_inkwargs .items ():
110
+ for k , v in src_inkwargs .items ():
116
111
new = v
117
112
if isinstance (v , torch .Tensor ):
118
113
new = v .to (device = device )
119
114
inkwargs [k ] = new
120
- for k , v in cpu_outargs .items ():
115
+ for k , v in src_outargs .items ():
121
116
new = v
122
117
if isinstance (v , torch .Tensor ):
123
118
new = v .to (device = device )
@@ -133,9 +128,30 @@ def run_meta_tuple(
133
128
posargs , inkwargs , outargs = self .move_to_device (
134
129
device , posargs , inkwargs , outargs
135
130
)
136
- success , res , posargs , inkwargs , outargs = self .run_values (
131
+ success , res , res_posargs , res_inkwargs , res_outargs = self .run_values (
137
132
meta_tuple , posargs , inkwargs , outargs
138
133
)
134
+ if (
135
+ self .valid
136
+ and success
137
+ and device != "cpu"
138
+ and isinstance (res , torch .Tensor )
139
+ ):
140
+ cpu_posargs , cpu_inkwargs , cpu_outargs = self .move_to_device (
141
+ "cpu" , posargs , inkwargs , outargs
142
+ )
143
+ (
144
+ cpu_success ,
145
+ cpu_res ,
146
+ cpu_res_posargs ,
147
+ cpu_res_inkwargs ,
148
+ cpu_res_outargs ,
149
+ ) = self .run_values (meta_tuple , cpu_posargs , cpu_inkwargs , cpu_outargs )
150
+ if cpu_success and cpu_res is not None :
151
+ if not torch .allclose (cpu_res , res .to ("cpu" )):
152
+ logging .warning (
153
+ f"NOT ALL CLOSE opname: { self .op_name } , meta_tuple: { smt (meta_tuple )} , device: { device } , { (cpu_res .to (torch .float ) - res .to ('cpu' ).to (torch .float )).abs ().max ()} "
154
+ )
139
155
mt = smt (meta_tuple )
140
156
if mt in self .results [device ]:
141
157
logging .warning (f"Repeated meta_tuple { mt } " )
@@ -183,7 +199,10 @@ def main():
183
199
raise RuntimeError (f"Op { args .op } not found in SpecDB" )
184
200
185
201
spec = SpecDictDB [args .op ]
186
- SpecRunner (spec , valid = not args .invalid , out = args .out , devices = args .devices ).run ()
202
+ op = get_op_overload (args .op )
203
+ SpecRunner (
204
+ op , spec , valid = not args .invalid , out = args .out , devices = args .devices
205
+ ).run ()
187
206
188
207
189
208
if __name__ == "__main__" :
0 commit comments