Commit bd2dd9b

Update on "Add test case to export, quantize and lower vision encoder model for ET"

Differential Revision: [D67878162](https://our.internmc.facebook.com/intern/diff/D67878162)

[ghstack-poisoned]

2 parents: 290e530 + 7d6f521

File tree: 3 files changed (+55, -24 lines)


examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 10 additions & 2 deletions
@@ -14,8 +14,12 @@
     replace_tile_positional_embedding,
     replace_tiled_token_positional_embedding,
 )
+from executorch.extension.llm.modules.attention import (
+    replace_mha_with_inference_mha,
+    replace_sdpa_with_custom_op,
+)
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder
-from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha, replace_sdpa_with_custom_op
+
 
 @dataclass
 class VisionEncoderConfig:
@@ -47,7 +51,11 @@ class VisionEncoderConfig:
 
 
 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: Optional[VisionEncoderConfig] = None, enable_source_transforms = True):
+    def __init__(
+        self,
+        config: Optional[VisionEncoderConfig] = None,
+        enable_source_transforms=True,
+    ):
        super().__init__()
        if config is None:
            config = demo_config
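
The two attention helpers imported above are the source transforms that source_transofrm() applies before export. A minimal sketch of how they chain, assuming both helpers take and return an nn.Module as their signatures in extension/llm/modules/attention.py suggest (the wrapper name apply_source_transforms is hypothetical):

    import torch

    from executorch.extension.llm.modules.attention import (
        replace_mha_with_inference_mha,
        replace_sdpa_with_custom_op,
    )


    def apply_source_transforms(model: torch.nn.Module) -> torch.nn.Module:
        # Hypothetical wrapper: swap torchtune MultiHeadAttention for the
        # inference-friendly variant, then route SDPA through the ExecuTorch
        # custom op (see replace_sdpa_with_custom_op in attention.py below).
        model = replace_mha_with_inference_mha(model)
        model = replace_sdpa_with_custom_op(model)
        return model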

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py

Lines changed: 43 additions & 20 deletions
@@ -11,25 +11,25 @@
 import unittest
 
 import torch
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackDynamicallyQuantizedPartitioner,
+)
 
 from executorch.examples.models.llama3_2_vision.vision_encoder import (
     FlamingoVisionEncoderModel,
 )
-from torch.testing import assert_close
-from executorch.exir import to_edge, to_edge_transform_and_lower, EdgeCompileConfig
-from torch._inductor.package import package_aoti
-from torch.nn.attention import SDPBackend
+from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
-from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
-    DuplicateDynamicQuantChainPass
-)
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
-    XnnpackDynamicallyQuantizedPartitioner,
-)
+from torch.nn.attention import SDPBackend
+from torch.testing import assert_close
+
 
 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
@@ -38,15 +38,30 @@ def setUp(self) -> None:
     def test_flamingo_vision_encoder_et(self) -> None:
         with torch.no_grad():
             vision_model = FlamingoVisionEncoderModel(enable_source_transforms=False)
-            encoder_no_source_transform_outputs = vision_model.model.forward(*vision_model.get_example_inputs())
+            encoder_no_source_transform_outputs = vision_model.model.forward(
+                *vision_model.get_example_inputs()
+            )
             vision_model.source_transofrm()
             encoder = vision_model.model
-            encoder_source_transform_outputs = encoder.forward(*vision_model.get_example_inputs())
-            assert_close(encoder_source_transform_outputs, encoder_no_source_transform_outputs)
+            encoder_source_transform_outputs = encoder.forward(
+                *vision_model.get_example_inputs()
+            )
+            assert_close(
+                encoder_source_transform_outputs, encoder_no_source_transform_outputs
+            )
 
-        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
-            training_output = torch.export.export_for_training(encoder, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
-            assert_close(encoder(*vision_model.get_example_inputs()), training_output.module()(*vision_model.get_example_inputs()))
+        with torch.nn.attention.sdpa_kernel(
+            [SDPBackend.MATH]
+        ), torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            training_output = torch.export.export_for_training(
+                encoder,
+                vision_model.get_example_inputs(),
+                dynamic_shapes=vision_model.get_dynamic_shapes(),
+            )
+            assert_close(
+                encoder(*vision_model.get_example_inputs()),
+                training_output.module()(*vision_model.get_example_inputs()),
+            )
 
             dynamic_quantizer = XNNPACKQuantizer()
             operator_config_dynamic = get_symmetric_quantization_config(
@@ -58,11 +73,19 @@ def test_flamingo_vision_encoder_et(self) -> None:
             convert = convert_pt2e(prepare)
             DuplicateDynamicQuantChainPass()(convert)
 
-            export_output = torch.export.export(convert, vision_model.get_example_inputs(), dynamic_shapes=vision_model.get_dynamic_shapes())
+            export_output = torch.export.export(
+                convert,
+                vision_model.get_example_inputs(),
+                dynamic_shapes=vision_model.get_dynamic_shapes(),
+            )
 
-            edge = to_edge_transform_and_lower(export_output, partitioner=[
-                XnnpackDynamicallyQuantizedPartitioner(),
-            ], compile_config=EdgeCompileConfig(_check_ir_validity=False))
+            edge = to_edge_transform_and_lower(
+                export_output,
+                partitioner=[
+                    XnnpackDynamicallyQuantizedPartitioner(),
+                ],
+                compile_config=EdgeCompileConfig(_check_ir_validity=False),
+            )
             edge.to_executorch()
 
     def test_flamingo_vision_encoder_aoti(self) -> None:
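
Read end to end, the test follows the standard PT2E dynamic-quantization recipe before lowering to XNNPACK. A condensed, hedged sketch of the same flow on an arbitrary module; the prepare_pt2e call sits in unchanged lines between the hunks above, so the quantizer setup and calibration step below are assumptions:

    import torch
    from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
        DuplicateDynamicQuantChainPass,
    )
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
    )
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )


    def quantize_and_lower(model: torch.nn.Module, example_inputs: tuple):
        # Annotate with the XNNPACK dynamic-quant config (assumed arguments).
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
        )
        # Capture, run one forward pass through the observers, and convert.
        captured = torch.export.export_for_training(model, example_inputs).module()
        prepared = prepare_pt2e(captured, quantizer)
        prepared(*example_inputs)
        converted = convert_pt2e(prepared)
        DuplicateDynamicQuantChainPass()(converted)

        # Re-export and lower the dynamically quantized ops to XNNPACK,
        # mirroring the partitioner and compile config used in the test.
        exported = torch.export.export(converted, example_inputs)
        edge = to_edge_transform_and_lower(
            exported,
            partitioner=[XnnpackDynamicallyQuantizedPartitioner()],
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        return edge.to_executorch()

to_executorch() returns an ExecutorchProgramManager whose serialized buffer can be written out as a .pte file for the ExecuTorch runtime, which is what the exported vision encoder would ship as.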

extension/llm/modules/attention.py

Lines changed: 2 additions & 2 deletions
@@ -9,11 +9,11 @@
 
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
-from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -367,7 +367,6 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
-
         output = self._attention_fn(
             q,
             k,
@@ -431,5 +430,6 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops
+
     _replace_sdpa_with_custom_op(module)
     return module
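
The unsqueeze/expand/flatten lines in the middle hunk are the usual grouped-query-attention trick: repeat each KV head until the key/value tensors line up with the query heads. A small self-contained illustration, with shapes assumed for the demo (the real module derives expand_shape from its own head counts):

    import torch

    # Assumed demo shapes: batch, kv heads, sequence length, head dim,
    # and query heads per kv head.
    b, n_kv, s, d, q_per_kv = 2, 4, 16, 64, 2
    k = torch.randn(b, n_kv, s, d)

    # Insert a repeat axis after the kv-head axis, broadcast it to
    # q_per_kv copies, then merge it back into the head axis.
    expand_shape = (b, n_kv, q_per_kv, s, d)
    k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    assert k.shape == (b, n_kv * q_per_kv, s, d)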
