55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8- try :
9- from typing import List
10-
11- print ("Dumping env: " , file = sys .stderr )
12- import os
13- for name , value in os .environ .items ():
14- print (f" { name } : { value } " )
15-
16- import sys
17- print ("Trying to import ET..." , file = sys .stderr )
18- import executorch
19- print (f"ET path: { executorch .__path__ } " , file = sys .stderr )
20-
21- import torch
22- from executorch .backends .xnnpack .partition .xnnpack_partitioner import XnnpackPartitioner
23- from executorch .examples .models import Backend , Model , MODEL_NAME_TO_MODEL
24- from executorch .examples .models .model_factory import EagerModelFactory
25- from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS
26- from executorch .examples .xnnpack .quantization .utils import quantize as quantize_xnn
27- from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
28- from executorch .extension .pybindings .portable_lib import (
29- _load_for_executorch_from_buffer ,
8+ from typing import List
9+
10+ import sys
11+ import executorch
12+
13+ import torch
14+ from executorch .backends .xnnpack .partition .xnnpack_partitioner import XnnpackPartitioner
15+ from executorch .examples .models import Backend , Model , MODEL_NAME_TO_MODEL
16+ from executorch .examples .models .model_factory import EagerModelFactory
17+ from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS
18+ from executorch .examples .xnnpack .quantization .utils import quantize as quantize_xnn
19+ from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
20+ from executorch .extension .pybindings .portable_lib import (
21+ _load_for_executorch_from_buffer ,
22+ )
23+ from test_base import ModelTest
24+
25+
26+ def test_model_xnnpack (model : Model , quantize : bool ) -> None :
27+ model_instance , example_inputs , _ , _ = EagerModelFactory .create_model (
28+ * MODEL_NAME_TO_MODEL [str (model )]
3029 )
31- from test_base import ModelTest
3230
31+ model_instance .eval ()
32+ ref_outputs = model_instance (* example_inputs )
3333
34- def test_model_xnnpack (model : Model , quantize : bool ) -> None :
35- model_instance , example_inputs , _ , _ = EagerModelFactory .create_model (
36- * MODEL_NAME_TO_MODEL [str (model )]
34+ if quantize :
35+ quant_type = MODEL_NAME_TO_OPTIONS [str (model )].quantization
36+ model_instance = torch .export .export_for_training (
37+ model_instance , example_inputs
38+ )
39+ model_instance = quantize_xnn (
40+ model_instance .module (), example_inputs , quant_type
3741 )
3842
39- model_instance .eval ()
40- ref_outputs = model_instance (* example_inputs )
41-
42- if quantize :
43- quant_type = MODEL_NAME_TO_OPTIONS [str (model )].quantization
44- model_instance = torch .export .export_for_training (
45- model_instance , example_inputs
46- )
47- model_instance = quantize_xnn (
48- model_instance .module (), example_inputs , quant_type
49- )
50-
51- lowered = to_edge_transform_and_lower (
52- torch .export .export (model_instance , example_inputs ),
53- partitioner = [XnnpackPartitioner ()],
54- compile_config = EdgeCompileConfig (
55- _check_ir_validity = False ,
56- ),
57- ).to_executorch ()
43+ lowered = to_edge_transform_and_lower (
44+ torch .export .export (model_instance , example_inputs ),
45+ partitioner = [XnnpackPartitioner ()],
46+ compile_config = EdgeCompileConfig (
47+ _check_ir_validity = False ,
48+ ),
49+ ).to_executorch ()
5850
59- loaded_model = _load_for_executorch_from_buffer (lowered .buffer )
60- et_outputs = loaded_model ([* example_inputs ])
51+ loaded_model = _load_for_executorch_from_buffer (lowered .buffer )
52+ et_outputs = loaded_model ([* example_inputs ])
6153
62- if isinstance (ref_outputs , torch .Tensor ):
63- ref_outputs = (ref_outputs ,)
54+ if isinstance (ref_outputs , torch .Tensor ):
55+ ref_outputs = (ref_outputs ,)
6456
65- assert len (ref_outputs ) == len (et_outputs )
66- for i in range (len (ref_outputs )):
67- assert torch .allclose (ref_outputs [i ], et_outputs [i ], atol = 1e-5 )
57+ assert len (ref_outputs ) == len (et_outputs )
58+ for i in range (len (ref_outputs )):
59+ assert torch .allclose (ref_outputs [i ], et_outputs [i ], atol = 1e-5 )
6860
6961
70- def run_tests (model_tests : List [ModelTest ]) -> None :
71- for model_test in model_tests :
72- if model_test .backend == Backend .Xnnpack :
73- test_model_xnnpack (model_test .model , quantize = False )
74- else :
75- raise RuntimeError (f"Unsupported backend { model_test .backend } ." )
62+ def run_tests (model_tests : List [ModelTest ]) -> None :
63+ for model_test in model_tests :
64+ if model_test .backend == Backend .Xnnpack :
65+ test_model_xnnpack (model_test .model , quantize = False )
66+ else :
67+ raise RuntimeError (f"Unsupported backend { model_test .backend } ." )
7668
7769
78- if __name__ == "__main__" :
79- run_tests (
80- model_tests = [
81- ModelTest (
82- model = Model .Mv3 ,
83- backend = Backend .Xnnpack ,
84- ),
85- ]
86- )
87- except :
88- pass
70+ if __name__ == "__main__" :
71+ run_tests (
72+ model_tests = [
73+ ModelTest (
74+ model = Model .Mv3 ,
75+ backend = Backend .Xnnpack ,
76+ ),
77+ ]
78+ )
0 commit comments