SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]


+def _create_fake_vl_inputs(model, fake_input_ids):
+    """Create fake vision-language model inputs for the export process.
+
+    Args:
+        model: The VL model.
+        fake_input_ids: The fake text input IDs tensor.
+
+    Returns:
+        dict: Dictionary of fake inputs for the VL model.
+    """
+    device = fake_input_ids.device
+    batch_size = fake_input_ids.shape[0]
+
+    # Create fake inputs based on common VL model patterns.
+    fake_inputs = {
+        "input_ids": fake_input_ids,
+        "attention_mask": torch.ones_like(fake_input_ids),
+    }
+
+    # Add vision-specific inputs based on the model configuration.
+    if hasattr(model.config, "vision_config"):
+        vision_config = model.config.vision_config
+        # Derive the fake image geometry from the vision config.
+        if hasattr(vision_config, "image_size"):
+            image_size = vision_config.image_size
+        else:
+            image_size = 224  # Default size
+
+        if hasattr(vision_config, "num_channels"):
+            num_channels = vision_config.num_channels
+        else:
+            num_channels = 3  # RGB default
+
+        # Create fake pixel values.
+        fake_inputs["pixel_values"] = torch.zeros(
+            [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
+        )
+
+    # Handle Nemotron-specific inputs.
+    model_name = getattr(model, "name_or_path", "").lower()
+    if "nemotron" in model_name:
+        # Nemotron models may need specific image flags.
+        fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device)
+
+    # Some VL models need aspect ratio information.
+    fake_inputs["aspect_ratio_ids"] = None
+    fake_inputs["aspect_ratio_mask"] = None
+    fake_inputs["cross_attention_mask"] = None
+
+    return fake_inputs
+
+
def _is_enabled_quantizer(quantizer):
    if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
        return True
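For reviewers, a minimal sketch of what the new helper returns for a CLIP-style vision config; the `SimpleNamespace` stand-ins and the model name are illustrative only, not part of this change:

```python
import torch
from types import SimpleNamespace

# Illustrative stand-ins for a Hugging Face VL model; real callers pass the model itself.
vision_config = SimpleNamespace(image_size=336, num_channels=3)
model = SimpleNamespace(
    config=SimpleNamespace(vision_config=vision_config),
    name_or_path="example-org/example-vl-model",  # hypothetical; the Nemotron branch is not taken
)

fake_input_ids = torch.ones([1, 2], dtype=torch.long)
fake_inputs = _create_fake_vl_inputs(model, fake_input_ids)

assert fake_inputs["pixel_values"].shape == (1, 3, 336, 336)
assert "image_flags" not in fake_inputs  # only added for Nemotron checkpoints
assert fake_inputs["cross_attention_mask"] is None
```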
@@ -131,6 +183,14 @@ def _output_hook(module, input, output):
    with torch.no_grad():
        fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
        decoder_fake_input = fake_input
+
+        # Check if this is a VL model that needs special input handling.
+        is_vl_model = (
+            hasattr(model.config, "vision_config")
+            or hasattr(model, "vision_model")
+            or "nemotron" in getattr(model, "name_or_path", "").lower()
+        )
+
        if model_type.startswith("whisper"):
            # For Whisper models, we need to pass a fake input with the specific sequence length.
            from transformers import AutoFeatureExtractor
@@ -139,13 +199,28 @@ def _output_hook(module, input, output):
            fake_input = torch.ones(
                [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
            ).to(model.device)
+        elif is_vl_model:
+            # For VL models, create proper fake vision inputs.
+            print("Detected VL model during export - creating fake vision inputs")
+            try:
+                # Try to create proper fake vision inputs for the VL model.
+                fake_kwargs = _create_fake_vl_inputs(model, fake_input)
+            except Exception as e:
+                print(f"Failed to create fake VL inputs: {e}")
+                print("Skipping requantize_resmooth_fused_llm_layers for VL model")
+                for handle in handles:
+                    handle.remove()
+                return

        # Run forward pass so that all modules sharing the same input are collected using forward hook.

        with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
            if getattr(model.config, "is_encoder_decoder", False):
                # For encoder-decoder models, we need to pass both the encoder and decoder input ids.
                model(fake_input, decoder_input_ids=decoder_fake_input)
+            elif is_vl_model:
+                # For VL models, use the fake vision inputs.
+                model(**fake_kwargs)
            else:
                model(fake_input)
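Taken together, the hunks make the calibration forward pass VL-aware. A condensed sketch of the resulting control flow, assuming `is_vl_model`, `set_quantizer_by_cfg_context`, and the new helper are in scope (Whisper branch and error handling elided):

```python
# Sketch only; the real code also handles Whisper inputs and removes the
# forward hooks if building the fake VL inputs fails.
with torch.no_grad():
    fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
    if is_vl_model:
        fake_kwargs = _create_fake_vl_inputs(model, fake_input)

    # Quantizers are disabled so the pass only drives the forward hooks.
    with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
        if getattr(model.config, "is_encoder_decoder", False):
            model(fake_input, decoder_input_ids=fake_input)
        elif is_vl_model:
            model(**fake_kwargs)
        else:
            model(fake_input)
```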