# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

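"""
FACTO-based operator tests: for each operator spec in the combined spec DB, this
module generates argument tuples with FACTO, wraps the op in a small nn.Module,
and runs it through a backend Tester (export, lower, and, when the op is
delegated, execute and compare outputs).
"""
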
import copy
import functools
import traceback
from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
import unittest

import torch
from executorch.backends.test.harness.tester import Tester as TesterBase
from executorch.backends.xnnpack.test.tester.tester import (
    ToEdgeTransformAndLower,
    Tester as XnnpackTester,
)
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import Constraint, ConstraintProducer as cp, Spec
from facto.inputgen.utils.random_manager import random_manager
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB
from torch._ops import OpOverload

from .facto_specs import ExtraSpecDB

CombinedSpecDB = SpecDictDB | ExtraSpecDB

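# Constrain generated tensors to rank 1..4, with each dimension size in [1, 2**9].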
COMMON_TENSOR_CONSTRAINTS = [
    cp.Rank.Ge(lambda deps: 1),
    cp.Rank.Le(lambda deps: 4),
    cp.Size.Ge(lambda deps, r, d: 1),
    cp.Size.Le(lambda deps, r, d: 2**9),
]

COMMON_SCALAR_CONSTRAINTS = [
    cp.Value.Ge(lambda deps, dtype: -1000),
    cp.Value.Le(lambda deps, dtype: 1000),
]

# Operator args are treated as runtime graph inputs if the argument name is
# in this list.
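# For example, aten.add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) has
# two leading runtime tensor inputs ("self" and "other"); any remaining args are
# baked into the wrapper module as fixed values at trace time.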
RUNTIME_INPUT_NAMES = {
    "self",
    "tensor",
    "other",
}

def _patch_spec(spec: Spec) -> Spec:
    spec = copy.deepcopy(spec)
    for inspec in spec.inspec:
        if inspec.type.is_tensor():
            inspec.constraints.extend(COMMON_TENSOR_CONSTRAINTS)
        elif inspec.type.is_scalar():
            inspec.constraints.extend(COMMON_SCALAR_CONSTRAINTS)
    return spec

class OpModel(torch.nn.Module):
    """
    Wraps a single torch operator in an nn.Module.
    """

    def __init__(
        self,
        op: OpOverload,
        runtime_input_count: int,
        fixed_args: Sequence[Any],
        fixed_kwargs: dict[str, Any],
    ):
        super().__init__()
        self.op = op
        self.runtime_input_count = runtime_input_count
        self.fixed_kwargs = fixed_kwargs

        # Register parameters for fixed tensors. Some things will choke on
        # constant tensor weights, for example.
        new_args = []
        for i, arg in enumerate(fixed_args):
            if isinstance(arg, torch.Tensor):
                param = torch.nn.Parameter(arg, requires_grad=False)
                param_name = f"arg_{i}_param"
                setattr(self, param_name, param)
                self.register_parameter(param_name, param)
                new_args.append(param)
            else:
                new_args.append(arg)
        self.fixed_args = tuple(new_args)

    def forward(self, *args, **kwargs):
        return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs))

class ConvModel(OpModel):
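    """
    Special-cases aten.convolution by dispatching to the matching
    torch.nn.functional conv / conv_transpose variant, chosen from the weight
    rank and the transposed flag carried in fixed_args.
    """
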
    def forward(self, *args, **kwargs):
        weight, bias, stride, padding, dilation, transposed, output_padding, groups = (
            self.fixed_args
        )

        if not transposed:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv3d

            return op(args[0], weight, bias, stride, padding, dilation, groups)
        else:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv_transpose1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv_transpose2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv_transpose3d

            return op(
                args[0], weight, bias, stride, padding, output_padding, groups, dilation
            )

def get_module_for_op(op: OpOverload):
    if op == torch.ops.aten.convolution.default:
        return ConvModel
    else:
        return OpModel

class FactoTestsBase(unittest.TestCase):
    def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tester_factory = tester_factory

    @staticmethod
    def _generate_test(op_name: str) -> None:
        # Find the torch op with the given name.
        sections = op_name.split(".")
        torch_op = functools.reduce(getattr, sections, torch.ops.aten)

        test_name = "test_" + op_name.replace(".", "_")
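        # e.g. "add.Tensor" -> "test_add_Tensor"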
        test_body = lambda self: self._test_op(torch_op)

        setattr(FactoTestsBase, test_name, test_body)

    @staticmethod
    def get_runtime_input_count(spec: Spec):
        # Determine which inputs are fixed at tracing time (weights, for example),
        # vs inputs to the runtime graph. We currently assume that the runtime graph
        # inputs start at the beginning of the arg list and are contiguous.
        #
        # Args are considered to be runtime inputs if they are positional and are named
        # one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a
        # runtime input.
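        # For example, aten.sub.Tensor(self, other, *, alpha) gives a count of 2:
        # "self" and "other" are leading tensors named in RUNTIME_INPUT_NAMES,
        # while "alpha" is a scalar and stops the scan.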
        runtime_input_count = 0
        for inspec in spec.inspec:
            is_runtime_input = (
                inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES
            )
            if is_runtime_input:
                runtime_input_count += 1
            else:
                break

        return max(1, runtime_input_count)

    def setUp(self):
        torch.set_printoptions(threshold=3)

    def _test_op(self, op: OpOverload) -> None:
        random_manager.seed(0)

        # Strip namespace
        op_name = op.name().split("::")[-1]

        # Default to .default overload
        if "." not in op_name:
            op_name += ".default"

        # Find and patch op spec
        if op_name not in CombinedSpecDB:
            raise ValueError(f"Operator {op_name} not found in the combined spec DB.")
        spec = _patch_spec(CombinedSpecDB[op_name])

        runtime_input_count = FactoTestsBase.get_runtime_input_count(spec)

        print(f"Op: {op_name}, {runtime_input_count} runtime inputs")

        # Run test cases
        success_count_delegated = 0
        success_count_undelegated = 0
        fail_count = 0

        i = 0
        for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen():
            i += 1

            try:
                if isinstance(posargs[0], torch.Tensor):
                    # Temporary for getting around XNN crashes (https://github.com/pytorch/executorch/issues/10960).
                    # TODO Re-enable when resolved.
                    if posargs[0].dtype in {torch.int8, torch.uint8}:
                        print("Skipping (u)int8 case.")
                        continue

                module_cls = get_module_for_op(op)
                model = module_cls(
                    op,
                    runtime_input_count,
                    posargs[runtime_input_count:],
                    inkwargs,
                )

                # Sanity check to make sure it runs in eager. This can present nicer error
                # messages sometimes compared to tracing.
                try:
                    model(*posargs[:runtime_input_count])
                except Exception as e:
                    print(f"Eager execution failed: {e}")
                    continue

                tester = self._tester_factory(
                    model,
                    tuple(posargs[:runtime_input_count]),
                )

                # Dynamo will also fail to handle some patterns that are valid in eager.
                try:
                    tester.export()
                except Exception as e:
                    print(f"Export failed: {e}")
                    continue

                tester.to_edge_transform_and_lower()

                is_delegated = any(
                    n.target == torch._higher_order_ops.executorch_call_delegate
                    for n in tester.stages[tester.cur].graph_module.graph.nodes
                    if n.op == "call_function"
                )

                # Only run the runtime test if the op was delegated.
                if is_delegated:
                    (
                        tester.to_executorch()
                        .serialize()
                        .run_method_and_compare_outputs()
                    )

                if is_delegated:
                    success_count_delegated += 1
                else:
                    success_count_undelegated += 1
            # finally:
            except Exception as e:
                fail_count += 1
                print("Args:")
                for arg in posargs:
                    if isinstance(arg, torch.Tensor):
                        print(f"  {arg.dtype} {arg.shape}")
                    else:
                        print(f"  {arg}")

                traceback.print_exc()

        print(
            f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL"
        )
        print(
            f"  {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED"
        )

# Programmatically generate tests for each operator.
for op_name in CombinedSpecDB.keys():
    FactoTestsBase._generate_test(op_name)

# TODO Figure out where to put these
class FactoTestsXNNPACK(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(XnnpackTester, *args, **kwargs)

try:
    from executorch.backends.apple.coreml.test.tester import CoreMLTester

    class FactoTestsCoreML(FactoTestsBase):
        def __init__(self, *args, **kwargs):
            super().__init__(CoreMLTester, *args, **kwargs)

except Exception:
    print("Skipping Core ML FACTO tests as Core ML AOT is not available.")