1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ # pyre-strict
8+
9+ import unittest
10+
11+ import torch
12+ from executorch .backends .xnnpack import get_xnnpack_recipe
13+ from executorch .exir .schema import DelegateCall , Program
14+ from executorch .export import export
15+ from torch import nn
16+ from torch .testing ._internal .common_quantization import TestHelperModules
17+ from torchvision import models
18+ from torchvision .models .mobilenetv2 import MobileNet_V2_Weights
19+ from executorch .backends .xnnpack .test .tester import Tester
20+ from torchvision .models .segmentation import deeplabv3 , deeplabv3_resnet50 # @manual
21+
22+
23+ class TestXnnpackRecipes (unittest .TestCase ):
24+ def setUp (self ) -> None :
25+ torch ._dynamo .reset ()
26+ super ().setUp ()
27+
28+ def tearDown (self ) -> None :
29+ super ().tearDown ()
30+
31+ def check_fully_delegated (self , program : Program ) -> None :
32+ instructions = program .execution_plan [0 ].chains [0 ].instructions
33+ assert instructions is not None
34+ self .assertEqual (len (instructions ), 1 )
35+ self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
36+
37+ def test_basic_recipe (self ) -> None :
38+ m_eager = TestHelperModules .TwoLinearModule ().eval ()
39+ example_inputs = [(torch .randn (9 , 8 ),)]
40+ session = export (
41+ model = m_eager ,
42+ example_inputs = example_inputs ,
43+ export_recipe = get_xnnpack_recipe ("FP32_RECIPE" ),
44+ )
45+ self .assertTrue (
46+ torch .allclose (
47+ session .run_method ("forward" , example_inputs [0 ])[0 ],
48+ m_eager (* example_inputs [0 ]),
49+ atol = 1e-1 ,
50+ )
51+ )
52+ self .check_fully_delegated (session .get_executorch_program ())
53+
54+ def test_dynamic_quant_recipe (self ) -> None :
55+ with torch .no_grad ():
56+ m_eager = TestHelperModules .TwoLinearModule ().eval ()
57+ example_inputs = [(torch .randn (9 , 8 ),)]
58+ session = export (
59+ model = m_eager ,
60+ example_inputs = example_inputs ,
61+ export_recipe = get_xnnpack_recipe (
62+ "DYNAMIC_PER_CHANNEL_QUANT_RECIPE"
63+ ),
64+ )
65+ self .assertTrue (
66+ torch .allclose (
67+ session .run_method ("forward" , example_inputs [0 ])[0 ],
68+ m_eager (* example_inputs [0 ]),
69+ atol = 1e-1 ,
70+ )
71+ )
72+ self .check_fully_delegated (session .get_executorch_program ())
73+
74+ def test_static_quant_recipe (self ) -> None :
75+ with torch .no_grad ():
76+ m_eager = TestHelperModules .TwoLinearModule ().eval ()
77+ example_inputs = [(torch .randn (9 , 8 ),)]
78+ session = export (
79+ model = m_eager ,
80+ example_inputs = example_inputs ,
81+ export_recipe = get_xnnpack_recipe (
82+ "STATIC_PER_CHANNEL_QUANT_RECIPE"
83+ ),
84+ )
85+ self .assertTrue (
86+ torch .allclose (
87+ session .run_method ("forward" , example_inputs [0 ])[0 ],
88+ m_eager (* example_inputs [0 ]),
89+ atol = 1e-1 ,
90+ )
91+ )
92+ self .check_fully_delegated (session .get_executorch_program ())
93+
94+ def test_8a4w_recipe (self ) -> None :
95+ class SimpleLinearModel (nn .Module ):
96+ def __init__ (self ) -> None :
97+ super (SimpleLinearModel , self ).__init__ ()
98+ self .layer1 = nn .Linear (32 , 2 )
99+
100+ def forward (self , x ) -> torch .Tensor :
101+ x = self .layer1 (x )
102+ return x
103+
104+ model = SimpleLinearModel ()
105+ example_inputs = [(torch .randn (1 , 32 ),)]
106+ session = export (
107+ model = model ,
108+ example_inputs = example_inputs ,
109+ export_recipe = get_xnnpack_recipe (
110+ "8A4W_ACCELERATED_RECIPE" , group_size = 32
111+ ),
112+ )
113+ self .assertTrue (
114+ torch .allclose (
115+ session .run_method ("forward" , example_inputs [0 ])[0 ],
116+ model (* example_inputs [0 ]),
117+ atol = 1e-1 ,
118+ )
119+ )
120+ self .check_fully_delegated (session .get_executorch_program ())
121+
122+ def test_mv3_model (self ) -> None :
123+ mv3 = models .mobilenetv3 .mobilenet_v3_small (pretrained = True )
124+ mv3 = mv3 .eval ()
125+ model_inputs = [(torch .randn (1 , 3 , 224 , 224 ),)]
126+ self .assertTrue (hasattr (mv3 , "forward" ))
127+ dynamic_shapes = ({2 : torch .export .Dim ("height" , min = 224 , max = 455 ), 3 : torch .export .Dim ("width" , min = 224 , max = 455 )},)
128+ session = export (
129+ model = mv3 ,
130+ example_inputs = model_inputs ,
131+ dynamic_shapes = dynamic_shapes ,
132+ export_recipe = get_xnnpack_recipe (
133+ "STATIC_PER_CHANNEL_QUANT_RECIPE"
134+ ),
135+ )
136+
137+ Tester ._assert_outputs_equal (
138+ session .run_method ("forward" , model_inputs [0 ])[0 ],
139+ mv3 (* model_inputs [0 ]),
140+ atol = 1e-3 ,
141+ )
142+
143+ def test_mv2_model_with_static_quant_recipe (self ) -> None :
144+ mv2 = models .mobilenetv2 .mobilenet_v2 (weights = MobileNet_V2_Weights )
145+ mv2 = mv2 .eval ()
146+ model_inputs = [(torch .randn (1 , 3 , 224 , 224 ),)]
147+ self .assertTrue (hasattr (mv2 , "forward" ))
148+ dynamic_shapes = ({2 : torch .export .Dim ("height" , min = 224 , max = 455 ), 3 : torch .export .Dim ("width" , min = 224 , max = 455 )},)
149+ session = export (
150+ model = mv2 ,
151+ example_inputs = model_inputs ,
152+ dynamic_shapes = dynamic_shapes ,
153+ export_recipe = get_xnnpack_recipe (
154+ "STATIC_PER_CHANNEL_QUANT_RECIPE"
155+ ),
156+ )
157+
158+ Tester ._assert_outputs_equal (
159+ session .run_method ("forward" , model_inputs [0 ])[0 ],
160+ mv2 (* model_inputs [0 ]),
161+ atol = 1e-3 ,
162+ )
163+
164+ def test_dl3_with_recipe (self ) -> None :
165+ class DL3Wrapper (torch .nn .Module ):
166+ def __init__ (self ):
167+ super ().__init__ ()
168+ self .m = deeplabv3_resnet50 (
169+ weights = deeplabv3 .DeepLabV3_ResNet50_Weights .DEFAULT
170+ )
171+
172+ def forward (self , * args ):
173+ return self .m (* args )["out" ]
174+
175+ dl3 = DL3Wrapper ()
176+ dl3 = dl3 .eval ()
177+ model_inputs = [(torch .randn (1 , 3 , 224 , 224 ),)]
178+ self .assertTrue (hasattr (dl3 , "forward" ))
179+ session = export (
180+ model = dl3 ,
181+ example_inputs = model_inputs ,
182+ export_recipe = get_xnnpack_recipe (
183+ "STATIC_PER_CHANNEL_QUANT_RECIPE"
184+ ),
185+ )
186+
187+ Tester ._assert_outputs_equal (
188+ session .run_method ("forward" , model_inputs [0 ])[0 ],
189+ dl3 (* model_inputs [0 ]),
190+ atol = 1e-3 ,
191+ )
192+
0 commit comments