77import argparse
88import logging
99import sys
10- from typing import Any , List , OrderedDict , Tuple
10+ from typing import Any , List , Optional , OrderedDict , Tuple
1111
1212import torch
13- from executorch .exir .dialects .edge .op .api import get_callable , to_variant
1413from facto .inputgen .argtuple .engine import MetaArgTupleEngine
1514from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
1615from facto .inputgen .argument .engine import MetaArg
1716from facto .inputgen .specs .model import Spec
17+ from facto .inputgen .utils .config import TensorConfig
1818from facto .specdb .db import SpecDictDB
19+ from facto .utils .ops import get_op_overload
1920from torch ._ops import OpOverload
20- from torchgen .model import SchemaKind
2121
2222logging .basicConfig (stream = sys .stderr , level = logging .WARNING )
2323
@@ -29,32 +29,27 @@ def smt(meta_tuple):
2929class SpecRunner :
3030 def __init__ (
3131 self ,
32+ op : OpOverload ,
3233 spec : Spec ,
3334 * ,
3435 valid : bool = True ,
3536 out : bool = False ,
3637 devices : Tuple [str ] = ("cpu" ,),
38+ config : Optional [TensorConfig ] = None ,
3739 ):
3840 self .spec = spec
39- self .generator = ArgumentTupleGenerator (self .spec )
41+ self .config = config
42+ self .generator = ArgumentTupleGenerator (self .spec , config = config )
4043 self .valid = valid
4144 self .out = out
42- self .op_name = spec . op
43- self .op = self . get_callable_op ()
45+ self .op_name = op . __name__
46+ self .op = op
4447 self .results = {}
4548 self .devices = devices
4649 self .results = {}
4750 for device in self .devices :
4851 self .results [device ] = {}
4952
50- def get_callable_op (self ):
51- name = self .spec .op
52- op : OpOverload = get_callable (name )
53- if self .out :
54- # Get the out variant op
55- op : OpOverload = to_variant (op , SchemaKind .out )
56- return op
57-
5853 def report_device (self , device ):
5954 print (f"Device: { device } \n " )
6055 failures = []
@@ -78,7 +73,7 @@ def report_inconsistencies(self):
7873 for device in self .devices [1 :]:
7974 res ^= self .results [device ][meta_tuple ]
8075 if not res :
81- inconsistencies .append (meta_tuple )
76+ inconsistencies .add (meta_tuple )
8277 if len (inconsistencies ) > 0 :
8378 print ("INCONSISTENCIES\n " )
8479 for meta_tuple in inconsistencies :
@@ -98,26 +93,26 @@ def run(self):
9893 def move_to_device (
9994 self ,
10095 device : str ,
101- cpu_posargs : List [Any ],
102- cpu_inkwargs : OrderedDict [str , Any ],
103- cpu_outargs : OrderedDict [str , Any ],
96+ src_posargs : List [Any ],
97+ src_inkwargs : OrderedDict [str , Any ],
98+ src_outargs : OrderedDict [str , Any ],
10499 ):
105- if device == "cpu" :
106- return cpu_posargs , cpu_inkwargs , cpu_outargs
100+ if device == ( "cpu" if self . config is None else self . config . device ) :
101+ return src_posargs , src_inkwargs , src_outargs
107102 posargs = []
108103 inkwargs = OrderedDict ()
109104 outargs = OrderedDict ()
110- for arg in cpu_posargs :
105+ for arg in src_posargs :
111106 new = arg
112107 if isinstance (arg , torch .Tensor ):
113108 new = arg .to (device = device )
114109 posargs .append (new )
115- for k , v in cpu_inkwargs .items ():
110+ for k , v in src_inkwargs .items ():
116111 new = v
117112 if isinstance (v , torch .Tensor ):
118113 new = v .to (device = device )
119114 inkwargs [k ] = new
120- for k , v in cpu_outargs .items ():
115+ for k , v in src_outargs .items ():
121116 new = v
122117 if isinstance (v , torch .Tensor ):
123118 new = v .to (device = device )
@@ -133,9 +128,30 @@ def run_meta_tuple(
133128 posargs , inkwargs , outargs = self .move_to_device (
134129 device , posargs , inkwargs , outargs
135130 )
136- success , res , posargs , inkwargs , outargs = self .run_values (
131+ success , res , res_posargs , res_inkwargs , res_outargs = self .run_values (
137132 meta_tuple , posargs , inkwargs , outargs
138133 )
134+ if (
135+ self .valid
136+ and success
137+ and device != "cpu"
138+ and isinstance (res , torch .Tensor )
139+ ):
140+ cpu_posargs , cpu_inkwargs , cpu_outargs = self .move_to_device (
141+ "cpu" , posargs , inkwargs , outargs
142+ )
143+ (
144+ cpu_success ,
145+ cpu_res ,
146+ cpu_res_posargs ,
147+ cpu_res_inkwargs ,
148+ cpu_res_outargs ,
149+ ) = self .run_values (meta_tuple , cpu_posargs , cpu_inkwargs , cpu_outargs )
150+ if cpu_success and cpu_res is not None :
151+ if not torch .allclose (cpu_res , res .to ("cpu" )):
152+ logging .warning (
153+ f"NOT ALL CLOSE opname: { self .op_name } , meta_tuple: { smt (meta_tuple )} , device: { device } , { (cpu_res .to (torch .float ) - res .to ('cpu' ).to (torch .float )).abs ().max ()} "
154+ )
139155 mt = smt (meta_tuple )
140156 if mt in self .results [device ]:
141157 logging .warning (f"Repeated meta_tuple { mt } " )
@@ -183,7 +199,10 @@ def main():
183199 raise RuntimeError (f"Op { args .op } not found in SpecDB" )
184200
185201 spec = SpecDictDB [args .op ]
186- SpecRunner (spec , valid = not args .invalid , out = args .out , devices = args .devices ).run ()
202+ op = get_op_overload (args .op )
203+ SpecRunner (
204+ op , spec , valid = not args .invalid , out = args .out , devices = args .devices
205+ ).run ()
187206
188207
189208if __name__ == "__main__" :
0 commit comments