
Commit 48153b9

Add test case to export, quantize and lower vision encoder model for ET
[ghstack-poisoned]
1 parent: 158779f

File tree: 2 files changed, +55 −6 lines changed

examples/models/llama3_2_vision/vision_encoder/model.py (10 additions & 4 deletions)

@@ -15,7 +15,7 @@
     replace_tiled_token_positional_embedding,
 )
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
-
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op
 
 @dataclass
 class VisionEncoderConfig:
@@ -47,7 +47,7 @@ class VisionEncoderConfig:
 
 
 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: Optional[VisionEncoderConfig] = None):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms: bool = True):
         super().__init__()
         if config is None:
             config = demo_config
@@ -64,8 +64,8 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             max_num_tiles=config.max_num_tiles,
             in_channels=config.in_channels,
         )
-        self.model = replace_tile_positional_embedding(self.model)
-        self.model = replace_tiled_token_positional_embedding(self.model)
+        if enable_source_transforms:
+            self.source_transform()
         self.image = torch.randn(
             1, 1, 4, 3, self.config.tile_size, self.config.tile_size
         )
@@ -75,6 +75,12 @@ def __init__(self, config: Optional[VisionEncoderConfig] = None):
             self.aspect_ratio,
         )
 
+    def source_transform(self):
+        self.model = replace_tile_positional_embedding(self.model)
+        self.model = replace_tiled_token_positional_embedding(self.model)
+        self.model = replace_mha_with_inference_mha(self.model)
+        self.model = replace_sdpa_with_custom_op(self.model)
+
     def get_eager_model(self, **kwargs):
         return self.model
 
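For context, a minimal sketch of how the new flag and method fit together (the import path below is assumed from the file location, not confirmed by this diff): construct the encoder with transforms disabled, capture an eager reference output, then apply the transforms in place and verify the numerics match.

import torch
from torch.testing import assert_close

# Assumed import path, mirroring examples/models/llama3_2_vision/vision_encoder/model.py.
from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
    FlamingoVisionEncoderModel,
)

with torch.no_grad():
    # Skip the transforms at construction time to get an eager baseline.
    m = FlamingoVisionEncoderModel(enable_source_transforms=False)
    reference = m.model(*m.get_example_inputs())

    # Apply the ExecuTorch-friendly rewrites in place, then re-check outputs.
    m.source_transform()
    assert_close(m.model(*m.get_example_inputs()), reference)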

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py (45 additions & 2 deletions)

@@ -16,13 +16,56 @@
     FlamingoVisionEncoderModel,
 )
 from torch.testing import assert_close
-
+from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
+from torch._inductor.package import package_aoti
+from torch.nn.attention import SDPBackend
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackDynamicallyQuantizedPartitioner,
+)
 
 
 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
 
-    def test_flamingo_vision_encoder(self) -> None:
+    def test_flamingo_vision_encoder_et(self) -> None:
+        with torch.no_grad():
+            vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
+            encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
+            vision_model.source_transform()
+            encoder = vision_model.model
+            encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
+            assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)
+
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+            assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))
+
+            dynamic_quantizer = XNNPACKQuantizer()
+            operator_config_dynamic = get_symmetric_quantization_config(
+                is_per_channel=True, is_dynamic=True
+            )
+            dynamic_quantizer.set_global(operator_config_dynamic)
+            prepare = prepare_pt2e(training_output.module(), dynamic_quantizer)
+            prepare(*vision_model.get_example_inputs())
+            convert = convert_pt2e(prepare)
+            DuplicateDynamicQuantChainPass()(convert)
+
+            export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+
+            edge = to_edge_transform_and_lower(export_output, partitioner=[
+                XnnpackDynamicallyQuantizedPartitioner(),
+            ], compile_config=EdgeCompileConfig(_check_ir_validity=False))
+            edge.to_executorch()
+
+    def test_flamingo_vision_encoder_aoti(self) -> None:
         model = FlamingoVisionEncoderModel()
         encoder = model.model
         eager_res = encoder.forward(*model.get_example_inputs())
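The quantize-and-lower sequence in the new test is the standard PT2E recipe. As a standalone illustration, here is a minimal sketch of that prepare → calibrate → convert flow on a toy module (the toy model and its shapes are invented for illustration, and it assumes a PyTorch build that provides torch.export.export_for_training, as the test itself does):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for the vision encoder."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


example_inputs = (torch.randn(2, 16),)

# Capture an exportable graph, as the test does with export_for_training.
exported = torch.export.export_for_training(ToyModel(), example_inputs)

# Configure dynamic, per-channel symmetric quantization for XNNPACK.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_inputs)  # one calibration pass over example inputs
quantized = convert_pt2e(prepared)  # rewrite with quantize/dequantize ops

After convert_pt2e, the quantized module can be exported again with torch.export.export and handed to to_edge_transform_and_lower with the XNNPACK partitioner, exactly as the test does before calling to_executorch().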
