from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3

+# Add this to src/transformers/integrations/executorch.py
+
+
+class TorchExportableModuleForVLM:
+    """
+    A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
+
+    This class handles the export of three main components:
+    1. Vision encoder (processes images into visual features)
+    2. Connector/projector (maps visual features into the text embedding space)
+    3. Text decoder (generates text from the combined visual and text tokens)
+    """
+
+    def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
+        """
+        Initialize the exportable VLM module.
+
+        Args:
+            model: The VLM (e.g. SmolVLM) model instance
+            max_batch_size: Maximum batch size; always 1 for ExecuTorch
+            max_cache_len: Maximum cache length for text generation
+        """
+        self.model = model
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = max_cache_len
+        self.config = model.config
+
+        # Extract the individual components
+        self.vision_encoder = model.model.vision_model
+        self.connector = model.model.connector
+        self.text_decoder = model.model.text_model
+
+        # Store the exported programs
+        self.exported_vision_encoder = None
+        self.exported_connector = None
+        self.exported_text_decoder = None
+
+    def export_vision_encoder(self):
+        """Export the vision encoder component."""
+        self.vision_encoder.eval()
+
+        # Create an example input
+        pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
+
+        # Define dynamic shapes: dims 2 and 3 (image height and width) stay dynamic
+        dynamic_shapes = {
+            "pixel_values": {
+                2: torch.export.Dim.AUTO,
+                3: torch.export.Dim.AUTO,
+            }
+        }
+
+        self.exported_vision_encoder = torch.export.export(
+            self.vision_encoder,
+            args=(pixel_values,),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_vision_encoder
+
+    def export_connector(self):
+        """Export the connector component."""
+        self.connector.eval()
+
+        # Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
+        vision_hidden_size = self.config.vision_config.hidden_size
+        image_size = self.config.vision_config.image_size
+        patch_size = self.config.vision_config.patch_size
+        patches_per_dim = image_size // patch_size
+        num_patches = patches_per_dim * patches_per_dim
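+        # e.g. a SigLIP-style encoder with image_size=384 and patch_size=14 gives 27 * 27 = 729 patches (illustrative values)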
+        image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
+
+        # Define dynamic shapes: static batch_size=1, dynamic num_patches
+        dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
+
+        # Export the connector using torch.export
+        self.exported_connector = torch.export.export(
+            self.connector,
+            args=(image_hidden_states,),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_connector
+
+    def export_text_decoder(self):
+        """Export the text decoder component."""
+
+        # Create the text decoder exportable wrapper
+        self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(
+            model=self.text_decoder,
+            max_batch_size=self.max_batch_size,
+            max_cache_len=self.max_cache_len,
+        )
+
+        # Build example inputs for the export
+        seq_length = 3
+        input_ids = torch.zeros((1, seq_length), dtype=torch.long)
+        cache_position = torch.arange(seq_length, dtype=torch.long)
+        max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
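+        # Bound the dynamic sequence dim so exported shapes stay within both the cache length and the positional limit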
+        seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
+
+        dynamic_shapes = {
+            "input_ids": {1: seq_len_dim},
+            "cache_position": {0: seq_len_dim},
+        }
+
+        self.exported_text_decoder = self.exportable_text_decoder.export(
+            input_ids=input_ids,
+            cache_position=cache_position,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+        return self.exported_text_decoder
+
+    def export(self, **kwargs):
+        """Export all components of the VLM model."""
+        self.export_vision_encoder(**kwargs)
+        self.export_connector(**kwargs)
+        self.export_text_decoder(**kwargs)
+        return {
+            "vision_encoder": self.exported_vision_encoder,
+            "connector": self.exported_connector,
+            "text_decoder": self.exported_text_decoder,
+        }
+
+    def forward(self, pixel_values, input_ids, cache_position):
+        """
+        Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
+
+        Args:
+            pixel_values: Input images [1, channels, height, width] (optional)
+            input_ids: Text token IDs [1, seq_len] (required - won't be None)
+            cache_position: Cache positions [seq_len] (required - won't be None)
+
+        Returns:
+            Output with logits for text generation
+        """
+        pass
+
+    def generate(
+        self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
+    ):
+        """
+        Simplified generate method with guaranteed non-null input_ids.
+
+        Args:
+            pixel_values: Input images [1, channels, height, width] (optional)
+            input_ids: Initial text tokens [1, seq_len] (required - won't be None)
+            max_new_tokens: Maximum number of tokens to generate
+            do_sample: Whether to use sampling or greedy decoding
+            temperature: Temperature for sampling
+
+        Returns:
+            Generated sequences
+        """
+        pass
+
+
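For context, a minimal usage sketch of the new wrapper. The checkpoint id, the `AutoModelForImageTextToText` loader, and the composed prefill below are illustrative assumptions, not part of this commit:

    import torch
    from transformers import AutoModelForImageTextToText
    from transformers.integrations.executorch import TorchExportableModuleForVLM

    # Load a SmolVLM2-style checkpoint (hypothetical id) and export the three components
    model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
    exportable = TorchExportableModuleForVLM(model, max_batch_size=1, max_cache_len=1024)
    programs = exportable.export()

    # Each value is a torch.export.ExportedProgram; .module() returns a runnable GraphModule
    pixel_values = torch.randn(1, 3, 384, 384)
    vision_features = programs["vision_encoder"].module()(pixel_values)
    image_embeds = programs["connector"].module()(vision_features)
    # The projected image embeddings are then merged with the text embeddings before
    # running programs["text_decoder"]; that merge step is model-specific.
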
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
@@ -64,7 +225,7 @@ def __init__(
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model)
+            self.model = TorchExportableModuleWithStaticCache(model, max_batch_size, max_cache_len)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -254,7 +415,12 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
     """

-    def __init__(self, model: PreTrainedModel):
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
         """
         Initializes the wrapper module with the pretrained model.
@@ -270,9 +436,16 @@ def __init__(self, model: PreTrainedModel):

         # Sanity checks
         if model.generation_config is None:
-            raise AssertionError(
-                "The model must have a generation config to be exported with static caching. "
-                "Please set `generation_config`."
+            # Use a default generation config if not specified
+            model.generation_config = GenerationConfig(
+                use_cache=model.config.use_cache,
+                cache_implementation="static",
+                max_length=max_cache_len,
+                cache_config={
+                    "batch_size": max_batch_size,
+                    "max_cache_len": max_cache_len,
+                    "device": "cpu",
+                },
             )

         if not model.generation_config.use_cache:
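With this fallback, a checkpoint that ships without a `generation_config` can be wrapped directly instead of raising. A minimal sketch (`model`, `input_ids`, and `cache_position` are assumed to be provided by the caller):

    # Previously this raised an AssertionError when generation_config was None
    wrapped = TorchExportableModuleWithStaticCache(model, max_batch_size=1, max_cache_len=4096)
    logits = wrapped(input_ids, cache_position)  # the static cache is created internally
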
@@ -332,7 +505,12 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
             past_key_values=past_key_values,
             use_cache=True,
         )
-        return outs.logits
+        if hasattr(outs, "logits"):
+            # The returned output is a `CausalLMOutputWithPast`
+            return outs.logits
+        else:
+            # Return the `last_hidden_state` from a `BaseModelOutputWithPast`
+            return outs.last_hidden_state

     @staticmethod
     def generate(
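This branching lets the same wrapper serve both a full causal LM, whose output carries `logits`, and a bare decoder backbone such as the `text_model` wrapped by `TorchExportableModuleForVLM`, whose output carries `last_hidden_state`. A sketch of the two cases (variable names are illustrative):

    lm_wrapper = TorchExportableModuleWithStaticCache(causal_lm)              # forward returns logits
    backbone_wrapper = TorchExportableModuleWithStaticCache(causal_lm.model)  # forward returns hidden states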