|
16 | 16 | FlamingoVisionEncoderModel, |
17 | 17 | ) |
18 | 18 | from torch.testing import assert_close |
19 | | - |
| 19 | +from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig |
| 20 | +from torch._inductor.package import package_aoti |
| 21 | +from torch.nn.attention import SDPBackend |
| 22 | +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 23 | +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| 24 | + get_symmetric_quantization_config, |
| 25 | + XNNPACKQuantizer, |
| 26 | +) |
| 27 | +from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( |
| 28 | +    DuplicateDynamicQuantChainPass, |
| 29 | +) |
| 30 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( |
| 31 | + XnnpackDynamicallyQuantizedPartitioner, |
| 32 | +) |
20 | 33 |
|
21 | 34 | class FlamingoVisionEncoderTest(unittest.TestCase): |
22 | 35 | def setUp(self) -> None: |
23 | 36 | super().setUp() |
24 | 37 |
|
25 | | - def test_flamingo_vision_encoder(self) -> None: |
| 38 | + def test_flamingo_vision_encoder_et(self) -> None: |
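| | + # Sanity check: applying source transforms should not change the encoder's eager outputs. |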
| 39 | + with torch.no_grad(): |
| 40 | + vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False) |
| 41 | + encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs()) |
| 42 | + vision_model.source_transform() |
| 43 | + encoder = vision_model.model |
| 44 | + encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs()) |
| 45 | + assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs) |
| 46 | + |
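| | + # Export under the math SDP backend so attention is traced as its decomposed, export-friendly form. |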
| 47 | + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): |
| 48 | + training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes()) |
| 49 | + assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs())) |
| 50 | + |
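| | + # Set up PT2E dynamic quantization for XNNPACK: per-channel, symmetric weight quantization. |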
| 51 | + dynamic_quantizer = XNNPACKQuantizer() |
| 52 | + operator_config_dynamic = get_symmetric_quantization_config( |
| 53 | + is_per_channel=True, is_dynamic=True |
| 54 | + ) |
| 55 | + dynamic_quantizer.set_global(operator_config_dynamic) |
| 56 | + prepared = prepare_pt2e(training_output.module(), dynamic_quantizer) |
| | + # Calibrate the prepared module once with the example inputs. |
| 57 | + prepared(*vision_model.get_example_inputs()) |
| 58 | + converted = convert_pt2e(prepared) |
| 59 | + DuplicateDynamicQuantChainPass()(converted) |
| 60 | + |
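| | + # Re-export the quantized module with the same dynamic shapes before lowering. |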
| 61 | + export_output = torch.export.export(converted, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes()) |
| 62 | + |
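| | + # Lower to XNNPACK with the dynamically quantized partitioner; IR validity checks are skipped since the quantized graph contains non-core-ATen ops. |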
| 63 | + edge = to_edge_transform_and_lower(export_output, partitioner=[ |
| 64 | + XnnpackDynamicallyQuantizedPartitioner(), |
| 65 | + ], compile_config=EdgeCompileConfig(_check_ir_validity=False)) |
| 66 | + edge.to_executorch() |
| 67 | + |
| 68 | + def test_flamingo_vision_encoder_aoti(self) -> None: |
26 | 69 | model = FlamingoVisionEncoderModel() |
27 | 70 | encoder = model.model |
28 | 71 | eager_res = encoder.forward(*model.get_example_inputs()) |