
Commit 2e77105

[llama-mm] Onboard Llama3.2 mm vision encoder
Summary: Add the Llama 3.2 multimodal vision encoder to examples/models. A module swap for TilePositionEmbedding is needed to make the vision encoder exportable.

Test Plan: Unit tests.

[ghstack-poisoned]
1 parent: df66f00
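For readers unfamiliar with the pattern, "module swapping" here means walking the model and replacing each TilePositionEmbedding submodule with an export-friendly implementation before export (the commit does this via replace_tile_positional_embedding, used in model.py below). The sketch that follows only illustrates the general pattern, not that helper's actual code; swap_modules and make_replacement are hypothetical names.

import torch.nn as nn

def swap_modules(model: nn.Module, target_type: type, make_replacement) -> nn.Module:
    # Recursively replace every child of `target_type` with the module produced
    # by `make_replacement`; recurse into all other children.
    for name, child in model.named_children():
        if isinstance(child, target_type):
            setattr(model, name, make_replacement(child))
        else:
            swap_modules(child, target_type, make_replacement)
    return model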

6 files changed: +145 −1 lines

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
     "llama2": ("llama", "Llama2Model"),
     "llama": ("llama", "Llama2Model"),
+    "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
     "lstm": ("lstm", "LSTMModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
examples/models/llama3_2_vision/vision_encoder/__init__.py

Lines changed: 12 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
    "FlamingoVisionEncoderModel",
    "VisionEncoderConfig",
]
examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 85 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules._position_embeddings import (
    replace_tile_positional_embedding,
)
from torchtune.models.flamingo._component_builders import flamingo_vision_encoder

max_seq_len = 8192
in_channels = 3
tile_size = 560
max_num_tiles = 4
# How many tokens per image are generated by the vision encoder.
tokens_per_image = 6404
# How many images to cache in the KV cache in cross attention.
kv_cache_image_num = 1
# Maximum number of tokens generated by the encoder and thus stored in the
# KV cache in cross attention.
encoder_max_seq_len = tokens_per_image * kv_cache_image_num


@dataclass
class VisionEncoderConfig:
    patch_size: int = 14
    num_heads: int = 16
    clip_embed_dim: int = 1280
    clip_num_layers: int = 32
    clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
    decoder_embed_dim: int = 4096
    num_layers_projection: int = 8
    tile_size: int = 560
    max_num_tiles: int = 4
    in_channels: int = 3


class FlamingoVisionEncoderModel(EagerModelBase):
    def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
        super().__init__()
        self.config = config
        self.model = flamingo_vision_encoder(
            patch_size=config.patch_size,
            num_heads=config.num_heads,
            clip_embed_dim=config.clip_embed_dim,
            clip_num_layers=config.clip_num_layers,
            clip_hidden_states=config.clip_hidden_states,
            decoder_embed_dim=config.decoder_embed_dim,
            num_layers_projection=config.num_layers_projection,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
            in_channels=config.in_channels,
        )
        self.image = torch.randn(
            1, 1, 4, 3, self.config.tile_size, self.config.tile_size
        )
        self.aspect_ratio = torch.tensor([[[1, 2]]])
        self.sample_inputs = (
            self.image,
            self.aspect_ratio,
        )

    def get_eager_model(self, **kwargs):
        self.model = replace_tile_positional_embedding(self.model)
        return self.model

    def get_example_inputs(self):
        return self.sample_inputs

    def get_dynamic_shapes(self):
        dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
        image_dynamic_dim = {
            0: 1,
            1: 1,
            2: dim,
            3: 3,
            4: self.config.tile_size,
            5: self.config.tile_size,
        }
        return (image_dynamic_dim, None)
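get_dynamic_shapes() marks dim 2 (num_tiles) of the image input as dynamic while pinning the remaining dims, and leaves the aspect-ratio input static. A hedged sketch of how these pieces might be fed to torch.export; the export call below is illustrative and not part of this commit.

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
)

m = FlamingoVisionEncoderModel()
encoder = m.get_eager_model()  # applies the TilePositionEmbedding module swap
ep = torch.export.export(
    encoder,
    m.get_example_inputs(),
    dynamic_shapes=m.get_dynamic_shapes(),
)
print(ep)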

examples/models/llama3_2_vision/vision_encoder/test/__init__.py

Whitespace-only changes.
examples/models/llama3_2_vision/vision_encoder/test/ (new test module)

Lines changed: 45 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for the CLIP vision encoder are covered by test_models.sh.
# Only AOTI is tested in this file.
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
    VisionEncoderConfig,
)
from torch._inductor.package import load_package, package_aoti


class FlamingoVisionEncoderTest(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def test_flamingo_vision_encoder(self) -> None:
        model = FlamingoVisionEncoderModel(VisionEncoderConfig())
        encoder = model.model
        eager_res = encoder.forward(*model.get_example_inputs())

        # AOTI
        so = torch._export.aot_compile(
            encoder,
            model.get_example_inputs(),
            options={"aot_inductor.package": True},
            dynamic_shapes=model.get_dynamic_shapes(),
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
            print(path)
            encoder_aoti = load_package(path)

            y = encoder_aoti(*model.get_example_inputs())

        self.assertTrue(torch.allclose(y, eager_res))

pytest.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ addopts =
     devtools/
     # examples
     examples/models/llama/tests
-    examples/models/llama3_2_vision/preprocess
+    examples/models/llama3_2_vision/preprocess/test
+    examples/models/llama3_2_vision/vision_encoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
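With the new test directory registered in pytest.ini above, the tests can also be invoked directly. A small illustrative snippet, assuming pytest is installed and the path is resolved from the repository root:

import pytest

# Run only the vision encoder tests added by this commit.
pytest.main(["examples/models/llama3_2_vision/vision_encoder/test", "-v"])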
