1010import functools
1111import traceback
1212import unittest
13- from typing import Any , Callable , List , OrderedDict , Sequence , Tuple
13+ from typing import Any , Callable , Sequence
1414
1515import torch
1616from executorch .backends .apple .coreml .test .tester import CoreMLTester
1717from executorch .backends .test .harness .tester import Tester as TesterBase
18- from executorch .backends .xnnpack .test .tester .tester import (
19- Tester as XnnpackTester ,
20- ToEdgeTransformAndLower ,
21- )
18+ from executorch .backends .xnnpack .test .tester .tester import Tester as XnnpackTester
2219from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
23- from facto .inputgen .specs .model import Constraint , ConstraintProducer as cp , Spec
20+ from facto .inputgen .specs .model import ConstraintProducer as cp , Spec
2421from facto .inputgen .utils .random_manager import random_manager
25- from facto .inputgen .variable .type import ScalarDtype
2622from facto .specdb .db import SpecDictDB
2723from torch ._ops import OpOverload
2824
3127CombinedSpecDB = SpecDictDB | ExtraSpecDB
3228
3329COMMON_TENSOR_CONSTRAINTS = [
34- cp .Rank .Ge (lambda deps : 1 ),
30+ cp .Rank .Ge (lambda deps : 1 ), # Avoid zero and high rank tensors.
3531 cp .Rank .Le (lambda deps : 4 ),
36- cp .Size .Ge (lambda deps , r , d : 1 ),
32+ cp .Size .Ge (lambda deps , r , d : 1 ), # Keep sizes reasonable.
3733 cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
3834]
3935
@@ -143,7 +139,9 @@ def _generate_test(op_name: str) -> None:
143139 torch_op = functools .reduce (getattr , sections , torch .ops .aten )
144140
145141 test_name = "test_" + op_name .replace ("." , "_" )
146- test_body = lambda self : self ._test_op (torch_op )
142+
143+ def test_body (self ):
144+ self ._test_op (torch_op )
147145
148146 setattr (FactoTestsBase , test_name , test_body )
149147
@@ -171,7 +169,7 @@ def get_runtime_input_count(spec: Spec):
171169 def setUp (self ):
172170 torch .set_printoptions (threshold = 3 )
173171
174- def _test_op (self , op : OpOverload ) -> None :
172+ def _test_op (self , op : OpOverload ) -> None : # noqa: C901
175173 random_manager .seed (0 )
176174
177175 # Strip namespace
@@ -182,7 +180,7 @@ def _test_op(self, op: OpOverload) -> None:
182180 op_name += ".default"
183181
184182 # Find and patch op spec
185- if not op_name in CombinedSpecDB :
183+ if op_name not in CombinedSpecDB :
186184 raise ValueError (f"Operator { op_name } not found in SpecDictDB." )
187185 spec = _patch_spec (CombinedSpecDB [op_name ])
188186
@@ -223,9 +221,7 @@ def _test_op(self, op: OpOverload) -> None:
223221 self ._tester_factory (model , tuple (posargs [:runtime_input_count ]))
224222 .export ()
225223 .dump_artifact ()
226- # .to_edge_transform_and_lower(ToEdgeTransformAndLower(partitioners=[]))
227224 .to_edge_transform_and_lower ()
228- # .dump_artifact()
229225 )
230226
231227 is_delegated = any (
@@ -248,7 +244,8 @@ def _test_op(self, op: OpOverload) -> None:
248244 success_count_undelegated += 1
249245 except Exception as e :
250246 fail_count += 1
251- print (f"Args:" )
247+ print (f"Error: { e } " )
248+ print ("Args:" )
252249 for arg in posargs :
253250 if isinstance (arg , torch .Tensor ):
254251 print (f" { arg .dtype } { arg .shape } " )
0 commit comments