1111import unittest
1212
1313import torch
14+ from executorch .backends .transforms .duplicate_dynamic_quant_chain import (
15+ DuplicateDynamicQuantChainPass ,
16+ )
17+ from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
18+ XnnpackDynamicallyQuantizedPartitioner ,
19+ )
1420
1521from executorch .examples .models .llama3_2_vision .vision_encoder import (
1622 FlamingoVisionEncoderModel ,
1723)
18- from torch .testing import assert_close
19- from executorch .exir import to_edge , to_edge_transform_and_lower , EdgeCompileConfig
20- from torch ._inductor .package import package_aoti
21- from torch .nn .attention import SDPBackend
24+ from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
2225from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
2326from torch .ao .quantization .quantizer .xnnpack_quantizer import (
2427 get_symmetric_quantization_config ,
2528 XNNPACKQuantizer ,
2629)
27- from executorch .backends .transforms .duplicate_dynamic_quant_chain import (
28- DuplicateDynamicQuantChainPass
29- )
30- from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
31- XnnpackDynamicallyQuantizedPartitioner ,
32- )
30+ from torch .nn .attention import SDPBackend
31+ from torch .testing import assert_close
32+
3333
3434class FlamingoVisionEncoderTest (unittest .TestCase ):
3535 def setUp (self ) -> None :
@@ -38,15 +38,30 @@ def setUp(self) -> None:
3838 def test_flamingo_vision_encoder_et (self ) -> None :
3939 with torch .no_grad ():
4040 vision_model = FlamingoVisionEncoderModel (enable_source_transforms = False )
41- encoder_no_source_transform_outputs = vision_model .model .forward (* vision_model .get_example_inputs ())
41+ encoder_no_source_transform_outputs = vision_model .model .forward (
42+ * vision_model .get_example_inputs ()
43+ )
4244 vision_model .source_transofrm ()
4345 encoder = vision_model .model
44- encoder_source_transform_outputs = encoder .forward (* vision_model .get_example_inputs ())
45- assert_close (encoder_source_transform_outputs , encoder_no_source_transform_outputs )
46+ encoder_source_transform_outputs = encoder .forward (
47+ * vision_model .get_example_inputs ()
48+ )
49+ assert_close (
50+ encoder_source_transform_outputs , encoder_no_source_transform_outputs
51+ )
4652
47- with torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad (), tempfile .TemporaryDirectory () as tmpdir :
48- training_output = torch .export .export_for_training (encoder , vision_model .get_example_inputs (), dynamic_shapes = vision_model .get_dynamic_shapes ())
49- assert_close (encoder (* vision_model .get_example_inputs ()), training_output .module ()(* vision_model .get_example_inputs ()))
53+ with torch .nn .attention .sdpa_kernel (
54+ [SDPBackend .MATH ]
55+ ), torch .no_grad (), tempfile .TemporaryDirectory () as tmpdir :
56+ training_output = torch .export .export_for_training (
57+ encoder ,
58+ vision_model .get_example_inputs (),
59+ dynamic_shapes = vision_model .get_dynamic_shapes (),
60+ )
61+ assert_close (
62+ encoder (* vision_model .get_example_inputs ()),
63+ training_output .module ()(* vision_model .get_example_inputs ()),
64+ )
5065
5166 dynamic_quantizer = XNNPACKQuantizer ()
5267 operator_config_dynamic = get_symmetric_quantization_config (
@@ -58,11 +73,19 @@ def test_flamingo_vision_encoder_et(self) -> None:
5873 convert = convert_pt2e (prepare )
5974 DuplicateDynamicQuantChainPass ()(convert )
6075
61- export_output = torch .export .export (convert , vision_model .get_example_inputs (), dynamic_shapes = vision_model .get_dynamic_shapes ())
76+ export_output = torch .export .export (
77+ convert ,
78+ vision_model .get_example_inputs (),
79+ dynamic_shapes = vision_model .get_dynamic_shapes (),
80+ )
6281
63- edge = to_edge_transform_and_lower (export_output , partitioner = [
64- XnnpackDynamicallyQuantizedPartitioner (),
65- ], compile_config = EdgeCompileConfig (_check_ir_validity = False ))
82+ edge = to_edge_transform_and_lower (
83+ export_output ,
84+ partitioner = [
85+ XnnpackDynamicallyQuantizedPartitioner (),
86+ ],
87+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
88+ )
6689 edge .to_executorch ()
6790
6891 def test_flamingo_vision_encoder_aoti (self ) -> None :
0 commit comments