
Commit d2ae766

Export SmolVLM (#39614)

Export SmolVLM for ExecuTorch

1 parent c430047 commit d2ae766

File tree: 5 files changed, +282 -18 lines changed

src/transformers/integrations/executorch.py
src/transformers/models/idefics2/modeling_idefics2.py
src/transformers/models/idefics3/modeling_idefics3.py
src/transformers/models/smolvlm/modeling_smolvlm.py
tests/models/smolvlm/test_modeling_smolvlm.py

src/transformers/integrations/executorch.py

Lines changed: 184 additions & 6 deletions
@@ -27,6 +27,167 @@
 from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
 
 
+# Add this to src/transformers/integrations/executorch.py
+
+
+class TorchExportableModuleForVLM:
+    """
+    A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
+
+    This class handles the export of three main components:
+    1. Vision encoder (processes images to visual features)
+    2. Connector/projector (maps visual features to text embedding space)
+    3. Text decoder (generates text from combined visual and text tokens)
+    """
+
+    def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
+        """
+        Initialize the exportable VLM module.
+
+        Args:
+            model: The VLM (e.g. SmolVLM) model instance
+            max_batch_size: Maximum batch size. Always 1 for ExecuTorch
+            max_cache_len: Maximum cache length for text generation
+        """
+        self.model = model
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = max_cache_len
+        self.config = model.config
+
+        # Extract individual components
+        self.vision_encoder = model.model.vision_model
+        self.connector = model.model.connector
+        self.text_decoder = model.model.text_model
+
+        # Store exported programs
+        self.exported_vision_encoder = None
+        self.exported_connector = None
+        self.exported_text_decoder = None
+
+    def export_vision_encoder(self):
+        """Export the vision encoder component."""
+        self.vision_encoder.eval()
+
+        # Create example input
+        pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
+
+        # Define dynamic shapes
+        dynamic_shapes = {
+            "pixel_values": {
+                2: torch.export.Dim.AUTO,
+                3: torch.export.Dim.AUTO,
+            }
+        }
+
+        self.exported_vision_encoder = torch.export.export(
+            self.vision_encoder,
+            args=(pixel_values,),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_vision_encoder
+
+    def export_connector(self):
+        """Export the connector component."""
+        self.connector.eval()
+
+        # Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
+        vision_hidden_size = self.config.vision_config.hidden_size
+        image_size = self.config.vision_config.image_size
+        patch_size = self.config.vision_config.patch_size
+        patches_per_dim = image_size // patch_size
+        num_patches = patches_per_dim * patches_per_dim
+        image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
+
+        # Define dynamic shapes - static batch_size=1, dynamic num_patches
+        dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
+
+        # Export the connector using torch.export
+        self.exported_connector = torch.export.export(
+            self.connector,
+            args=(image_hidden_states,),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_connector
+
+    def export_text_decoder(self):
+        """Export the text decoder component."""
+
+        # Create text decoder exportable wrapper
+        self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(
+            model=self.text_decoder,
+            max_batch_size=self.max_batch_size,
+            max_cache_len=self.max_cache_len,
+        )
+
+        # Use the existing text decoder exportable wrapper
+        seq_length = 3
+        input_ids = torch.zeros((1, seq_length), dtype=torch.long)
+        cache_position = torch.arange(seq_length, dtype=torch.long)
+        max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
+        seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
+
+        dynamic_shapes = {
+            "input_ids": {1: seq_len_dim},
+            "cache_position": {0: seq_len_dim},
+        }
+
+        self.exported_text_decoder = self.exportable_text_decoder.export(
+            input_ids=input_ids,
+            cache_position=cache_position,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_text_decoder
+
+    def export(self, **kwargs):
+        """Export all components of the VLM model."""
+        self.export_vision_encoder(**kwargs)
+        self.export_connector(**kwargs)
+        self.export_text_decoder(**kwargs)
+        return {
+            "vision_encoder": self.exported_vision_encoder,
+            "connector": self.exported_connector,
+            "text_decoder": self.exported_text_decoder,
+        }
+
+    def forward(self, pixel_values, input_ids, cache_position):
+        """
+        Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
+
+        Args:
+            pixel_values: Input images [1, channels, height, width] (optional)
+            input_ids: Text token IDs [1, seq_len] (required - won't be None)
+            cache_position: Cache positions [seq_len] (required - won't be None)
+
+        Returns:
+            Output with logits for text generation
+        """
+        pass
+
+    def generate(
+        self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
+    ):
+        """
+        Simplified generate method with guaranteed non-null input_ids.
+
+        Args:
+            pixel_values: Input images [1, channels, height, width] (optional)
+            input_ids: Initial text tokens [1, seq_len] (required - won't be None)
+            max_new_tokens: Maximum number of tokens to generate
+            do_sample: Whether to use sampling or greedy decoding
+            temperature: Temperature for sampling
+
+        Returns:
+            Generated sequences
+        """
+        pass
+
+
 class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     """
     A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
@@ -64,7 +225,7 @@ def __init__(
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model)
+            self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -254,7 +415,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
     """
 
-    def __init__(self, model: PreTrainedModel):
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
         """
         Initializes the wrapper module with the pretrained model.
 
@@ -270,9 +436,16 @@ def __init__(self, model: PreTrainedModel):
 
         # Sanity checks
         if model.generation_config is None:
-            raise AssertionError(
-                "The model must have a generation config to be exported with static caching. "
-                "Please set `generation_config`."
+            # Use default generation config if not specified
+            model.generation_config = GenerationConfig(
+                use_cache=model.config.use_cache,
+                cache_implementation="static",
+                max_length=max_cache_len,
+                cache_config={
+                    "batch_size": max_batch_size,
+                    "max_cache_len": max_cache_len,
+                    "device": "cpu",
+                },
             )
 
         if not model.generation_config.use_cache:
@@ -332,7 +505,12 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
             past_key_values=past_key_values,
             use_cache=True,
         )
-        return outs.logits
+        if hasattr(outs, "logits"):
+            # The returned output is a `CausalLMOutputWithPast`
+            return outs.logits
+        else:
+            # Return the `last_hidden_state` from a `BaseModelOutputWithPast`
+            return outs.last_hidden_state
 
     @staticmethod
     def generate(
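
Taken together, these executorch.py changes let a SmolVLM checkpoint be wrapped and all three components exported in one call, while the static-cache wrapper now builds a default GenerationConfig and accepts a bare text model that returns a `BaseModelOutputWithPast`. Below is a minimal usage sketch rather than code from this commit; it assumes the `HuggingFaceTB/SmolVLM2-256M-Video-Instruct` checkpoint and the `SmolVLMForConditionalGeneration` class used in the new tests further down.

import torch
from transformers import AutoConfig, SmolVLMForConditionalGeneration
from transformers.integrations.executorch import TorchExportableModuleForVLM

model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"

# ExecuTorch does not support flash attention, so disable it before export
config = AutoConfig.from_pretrained(model_id)
config.text_config._flash_attn_2_enabled = False

model = SmolVLMForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    config=config,
)

# Export the vision encoder, connector, and text decoder with torch.export
exportable = TorchExportableModuleForVLM(model, max_batch_size=1, max_cache_len=1024)
programs = exportable.export()
for name, program in programs.items():
    print(name, isinstance(program, torch.export.ExportedProgram))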

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 5 additions & 2 deletions
@@ -147,8 +147,11 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             nb_patches_h = p_attn_mask[:, 0].sum()
             nb_patches_w = p_attn_mask[0].sum()
 
-            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+            h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
+            w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
+
+            fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+            fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
 
             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
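
The same replacement is applied in idefics3 and smolvlm below, presumably because the old form passes a data-dependent tensor as the step of `torch.arange`, which does not trace well under `torch.export`. As a standalone illustration (my own sketch with made-up sizes, not code from this commit), the index-based form yields the old coordinates scaled by `(1 - 1e-6)`, so the `bucketize` assignments agree except possibly for values that land exactly on a boundary:

import torch

nb_patches = 12  # stands in for p_attn_mask[:, 0].sum() in the model
boundaries = torch.linspace(0.0, 1.0, steps=28)  # hypothetical bucket boundaries

old_coords = torch.arange(0, 1 - 1e-6, 1 / nb_patches)
new_coords = torch.arange(nb_patches, dtype=torch.float32) / nb_patches * (1 - 1e-6)

print(torch.bucketize(old_coords, boundaries, right=True))
print(torch.bucketize(new_coords, boundaries, right=True))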

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 8 additions & 5 deletions
@@ -147,8 +147,11 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             nb_patches_h = p_attn_mask[:, 0].sum()
             nb_patches_w = p_attn_mask[0].sum()
 
-            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+            h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
+            w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
+
+            fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+            fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
 
             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
@@ -558,10 +561,10 @@ def forward(
         # The call to `_upad_input` in `_flash_attention_forward` is expensive
         # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
         # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
-        elif not self._use_flash_attention_2:
+        if not self._use_flash_attention_2:
             patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+        elif not torch.any(~patch_attention_mask):
+            patch_attention_mask = None
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
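
The same reordering appears in smolvlm below. A minimal sketch of the resulting control flow (my own illustration with a hypothetical helper, not this model's code): `self._use_flash_attention_2` is a plain Python bool resolved at trace time, so the export-friendly non-flash path now builds the 4D mask without ever reaching the data-dependent `torch.any(~mask)` check, while the flash-attention-2 eager path keeps the all-ones shortcut:

import torch

def resolve_patch_mask(patch_attention_mask, use_flash_attention_2, dtype):
    if not use_flash_attention_2:
        # Hypothetical stand-in for `_prepare_4d_attention_mask`: expand (bsz, seq_len) to (bsz, 1, 1, seq_len)
        return patch_attention_mask[:, None, None, :].to(dtype)
    elif not torch.any(~patch_attention_mask):
        # Every patch is attended to, so the mask can be dropped for flash attention
        return None
    return patch_attention_mask

mask = torch.ones(1, 64, dtype=torch.bool)
print(resolve_patch_mask(mask, use_flash_attention_2=False, dtype=torch.float32).shape)  # torch.Size([1, 1, 1, 64])
print(resolve_patch_mask(mask, use_flash_attention_2=True, dtype=torch.float32))  # None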

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 8 additions & 5 deletions
@@ -142,8 +142,11 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             nb_patches_h = p_attn_mask[:, 0].sum()
             nb_patches_w = p_attn_mask[0].sum()
 
-            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+            h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
+            w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
+
+            fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
+            fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
 
             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
@@ -445,10 +448,10 @@ def forward(
         # The call to `_upad_input` in `_flash_attention_forward` is expensive
         # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
         # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
-        elif not self._use_flash_attention_2:
+        if not self._use_flash_attention_2:
             patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+        elif not torch.any(~patch_attention_mask):
+            patch_attention_mask = None
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,

tests/models/smolvlm/test_modeling_smolvlm.py

Lines changed: 77 additions & 0 deletions
@@ -595,3 +595,80 @@ def test_integration_test_video(self):
 
         expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature'  # fmt: skip
         self.assertEqual(generated_texts[0], expected_generated_text)
+
+    @slow
+    def test_export_smolvlm_vision_encoder(self):
+        from transformers import AutoConfig
+        from transformers.integrations.executorch import TorchExportableModuleForVLM
+
+        model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
+
+        # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not.
+        # For ExecuTorch, flash attention is not supported, so the way the vision encoder is exported should be compatible with the text decoder.
+        config = AutoConfig.from_pretrained(model_id)
+        config.text_config._flash_attn_2_enabled = False
+
+        # Load the model and extract the vision encoder
+        model = SmolVLMForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,
+            config=config,
+        )
+
+        exportable_module = TorchExportableModuleForVLM(model)
+        exported_program = exportable_module.export_vision_encoder()
+        self.assertIsInstance(exported_program, torch.export.ExportedProgram)
+
+    @slow
+    def test_export_smolvlm_connector(self):
+        from transformers import AutoConfig
+        from transformers.integrations.executorch import TorchExportableModuleForVLM
+
+        model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
+
+        # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not.
+        # For ExecuTorch, flash attention is not supported, so the way the vision encoder is exported should be compatible with the text decoder.
+        config = AutoConfig.from_pretrained(model_id)
+        config.text_config._flash_attn_2_enabled = False
+
+        # Load the model and extract the connector (multi-modal projector)
+        model = SmolVLMForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,
+            config=config,
+        )
+
+        connector = model.model.connector
+        connector.eval()
+
+        exportable_module = TorchExportableModuleForVLM(model)
+        exported_program = exportable_module.export_connector()
+        self.assertIsInstance(exported_program, torch.export.ExportedProgram)
+
+    @slow
+    def test_export_smolvlm_text_decoder(self):
+        from transformers import AutoConfig
+        from transformers.integrations.executorch import TorchExportableModuleForVLM
+
+        model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
+
+        # NOTE: The attention_mask is prepared internally in the vision encoder, depending on whether flash attention is used or not.
+        # For ExecuTorch, flash attention is not supported, so the way the vision encoder is exported should be compatible with the text decoder.
+        config = AutoConfig.from_pretrained(model_id)
+        config.text_config._flash_attn_2_enabled = False
+        config.text_config.use_cache = True
+        config.text_config.attn_implementation = "sdpa"
+
+        # Load the model and extract the text decoder
+        model = SmolVLMForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,
+            config=config,
+        )
+
+        text_decoder = model.model.text_model
+        text_decoder.eval()
+
+        exportable_module = TorchExportableModuleForVLM(model)
+        exported_program = exportable_module.export_text_decoder()
+        self.assertIsInstance(exported_program, torch.export.ExportedProgram)
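
Beyond the isinstance checks above, a natural follow-up is to run an exported program through its `.module()` and compare it against the eager component. The sketch below is not part of the test suite; it assumes the exported vision encoder preserves the eager output structure (so the `last_hidden_state` access and tolerances are illustrative) and reuses the checkpoint from the tests:

import torch
from transformers import AutoConfig, SmolVLMForConditionalGeneration
from transformers.integrations.executorch import TorchExportableModuleForVLM

model_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
config = AutoConfig.from_pretrained(model_id)
config.text_config._flash_attn_2_enabled = False
model = SmolVLMForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32, config=config)

exportable = TorchExportableModuleForVLM(model)
program = exportable.export_vision_encoder()

# Run the exported graph and the eager vision encoder on the same input
pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
with torch.no_grad():
    eager_out = model.model.vision_model(pixel_values).last_hidden_state
    exported_out = program.module()(pixel_values).last_hidden_state
torch.testing.assert_close(exported_out, eager_out, rtol=1e-4, atol=1e-4)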
