import sys
from typing import Any, List, OrderedDict, Tuple

+import torch
from executorch.exir.dialects.edge.op.api import get_callable, to_variant
from inputgen.argtuple.engine import MetaArgTupleEngine
from inputgen.argtuple.gen import ArgumentTupleGenerator
@@ -26,13 +27,25 @@ def smt(meta_tuple):
class SpecRunner:
-    def __init__(self, spec: Spec, *, valid: bool = True, out: bool = False):
+    def __init__(
+        self,
+        spec: Spec,
+        *,
+        valid: bool = True,
+        out: bool = False,
+        devices: Tuple[str] = ("cpu",),
+    ):
        self.spec = spec
        self.generator = ArgumentTupleGenerator(self.spec)
        self.valid = valid
        self.out = out
        self.op_name = spec.op
        self.op = self.get_callable_op()
+        self.devices = devices
+        self.results = {}
+        for device in self.devices:
+            self.results[device] = {}

    def get_callable_op(self):
        name = self.spec.op
@@ -42,26 +55,93 @@ def get_callable_op(self):
        op: OpOverload = to_variant(op, SchemaKind.out)
        return op

-    def run(self):
+    def report_device(self, device):
+        print(f"Device: {device}\n")
        failures = []
-        engine = MetaArgTupleEngine(self.spec, out=self.out)
-        for meta_tuple in engine.gen(valid=self.valid):
-            success, _, _, _, _ = self.run_meta_tuple(meta_tuple)
+        for meta_tuple in self.results[device]:
+            success = self.results[device][meta_tuple]
            if not success:
                failures.append(meta_tuple)
        if len(failures) > 0:
            print("FAILURES\n")
            for meta_tuple in failures:
-                print(f"\t{smt(meta_tuple)}\n")
+                print(f"\t{meta_tuple}\n")
        else:
            print("SUCCESS\n")

+    def report_inconsistencies(self):
+        print(f"Devices: {' '.join(self.devices)}\n")
+        meta_tuples = self.results[self.devices[0]].keys()
+        inconsistencies = set()
+        for meta_tuple in meta_tuples:
+            # Flag a meta tuple as inconsistent when devices disagree on success.
+            res = {self.results[device][meta_tuple] for device in self.devices}
+            if len(res) > 1:
+                inconsistencies.add(meta_tuple)
+        if len(inconsistencies) > 0:
+            print("INCONSISTENCIES\n")
+            for meta_tuple in inconsistencies:
+                res = [self.results[d][meta_tuple] for d in self.devices]
+                res_string = " ".join(["x" if r else "o" for r in res])
+                print(f"\t{res_string} {meta_tuple}\n")
+
+    def run(self):
+        engine = MetaArgTupleEngine(self.spec, out=self.out)
+        for meta_tuple in engine.gen(valid=self.valid):
+            self.run_meta_tuple(meta_tuple)
+        if len(self.devices) > 1:
+            self.report_inconsistencies()
+        for device in self.devices:
+            self.report_device(device)
+
+    def move_to_device(
+        self,
+        device: str,
+        cpu_posargs: List[Any],
+        cpu_inkwargs: OrderedDict[str, Any],
+        cpu_outargs: OrderedDict[str, Any],
+    ):
+        if device == "cpu":
+            return cpu_posargs, cpu_inkwargs, cpu_outargs
+        posargs = []
+        inkwargs = OrderedDict()
+        outargs = OrderedDict()
+        for arg in cpu_posargs:
+            new = arg
+            if isinstance(arg, torch.Tensor):
+                new = arg.to(device=device)
+            posargs.append(new)
+        for k, v in cpu_inkwargs.items():
+            new = v
+            if isinstance(v, torch.Tensor):
+                new = v.to(device=device)
+            inkwargs[k] = new
+        for k, v in cpu_outargs.items():
+            new = v
+            if isinstance(v, torch.Tensor):
+                new = v.to(device=device)
+            outargs[k] = new
+        return posargs, inkwargs, outargs
+

    def run_meta_tuple(
        self, meta_tuple: Tuple[MetaArg]
    ) -> Tuple[bool, Any, List[Any], OrderedDict[str, Any], OrderedDict[str, Any]]:
        print(f"Running op: {self.op_name}, meta_tuple: {[str(x) for x in meta_tuple]}")
        posargs, inkwargs, outargs = self.generator.gen_tuple(meta_tuple, out=self.out)
-        return self.run_values(meta_tuple, posargs, inkwargs, outargs)
+        for device in self.devices:
+            posargs, inkwargs, outargs = self.move_to_device(
+                device, posargs, inkwargs, outargs
+            )
+            success, res, posargs, inkwargs, outargs = self.run_values(
+                meta_tuple, posargs, inkwargs, outargs
+            )
+            mt = smt(meta_tuple)
+            if mt in self.results[device]:
+                logging.warning(f"Repeated meta_tuple {mt}")
+                self.results[device][mt] &= success
+            else:
+                self.results[device][mt] = success

    def run_values(
        self,
@@ -96,13 +176,14 @@ def main():
        "--invalid", action="store_true", help="generate invalid inputs"
    )
    parser.add_argument("--out", action="store_true", help="run out variants")
+    parser.add_argument("--devices", nargs="*", default=("cpu",), help="run on devices")
    args = parser.parse_args()

    if args.op not in SpecDictDB:
        raise RuntimeError(f"Op {args.op} not found in SpecDB")

    spec = SpecDictDB[args.op]
-    SpecRunner(spec, valid=not args.invalid, out=args.out).run()
+    SpecRunner(spec, valid=not args.invalid, out=args.out, devices=args.devices).run()


if __name__ == "__main__":
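
With the new devices parameter, a single run collects per-device results and reports any cross-device disagreements before the per-device pass/fail summaries. A minimal usage sketch, not part of the diff; the op key "add.Tensor" and the availability of a CUDA device are illustrative assumptions:

    # Hypothetical driver mirroring main() above; "add.Tensor" as a SpecDictDB key
    # and the presence of a CUDA device are assumptions for illustration only.
    spec = SpecDictDB["add.Tensor"]
    SpecRunner(spec, valid=True, out=False, devices=("cpu", "cuda")).run()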