# Submodule attribute names that indicate a speculative-decoding head is
# attached to the model (checked via hasattr elsewhere in this module).
SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
7676
7777
78+ def _create_fake_vl_inputs (model , fake_input_ids ):
79+ """Create fake vision-language model inputs for export process.
80+
81+ Args:
82+ model: The VL model
83+ fake_input_ids: The fake text input IDs tensor
84+
85+ Returns:
86+ dict: Dictionary of fake inputs for the VL model
87+ """
88+ device = fake_input_ids .device
89+ batch_size = fake_input_ids .shape [0 ]
90+
91+ # Create fake inputs based on common VL model patterns
92+ fake_inputs = {
93+ "input_ids" : fake_input_ids ,
94+ "attention_mask" : torch .ones_like (fake_input_ids ),
95+ }
96+
97+ # Add vision-specific inputs based on model configuration
98+ if hasattr (model .config , "vision_config" ):
99+ vision_config = model .config .vision_config
100+ # Create fake pixel values based on vision config
101+ if hasattr (vision_config , "image_size" ):
102+ image_size = vision_config .image_size
103+ else :
104+ image_size = 224 # Default size
105+
106+ if hasattr (vision_config , "num_channels" ):
107+ num_channels = vision_config .num_channels
108+ else :
109+ num_channels = 3 # RGB default
110+
111+ # Create fake pixel values
112+ fake_inputs ["pixel_values" ] = torch .zeros (
113+ [batch_size , num_channels , image_size , image_size ], dtype = torch .float32 , device = device
114+ )
115+
116+ # Handle Nemotron-specific inputs
117+ model_name = getattr (model , "name_or_path" , "" ).lower ()
118+ if "nemotron" in model_name :
119+ # Nemotron models may need specific image flags
120+ fake_inputs ["image_flags" ] = torch .zeros ([batch_size , 1 ], dtype = torch .long , device = device )
121+
122+ # Some VL models need aspect ratio information
123+ fake_inputs ["aspect_ratio_ids" ] = None
124+ fake_inputs ["aspect_ratio_mask" ] = None
125+ fake_inputs ["cross_attention_mask" ] = None
126+
127+ return fake_inputs
128+
129+
78130def _is_enabled_quantizer (quantizer ):
79131 if hasattr (quantizer , "is_enabled" ) and quantizer .is_enabled :
80132 return True
@@ -134,6 +186,14 @@ def _output_hook(module, input, output):
134186 with torch .no_grad ():
135187 fake_input = torch .ones ([1 , 2 ], dtype = torch .long ).to (model .device )
136188 decoder_fake_input = fake_input
189+
190+ # Check if this is a VL model that needs special input handling
191+ is_vl_model = (
192+ hasattr (model .config , "vision_config" )
193+ or hasattr (model , "vision_model" )
194+ or "nemotron" in getattr (model , "name_or_path" , "" ).lower ()
195+ )
196+
137197 if model_type .startswith ("whisper" ):
138198 # For Whisper models, we need to pass a fake input with the specific sequence length
139199 from transformers import AutoFeatureExtractor
@@ -142,13 +202,28 @@ def _output_hook(module, input, output):
142202 fake_input = torch .ones (
143203 [1 , model .config .num_mel_bins , feature_extractor .nb_max_frames ], dtype = model .dtype
144204 ).to (model .device )
205+ elif is_vl_model :
206+ # For VL models, create proper fake vision inputs
207+ print ("Detected VL model during export - creating fake vision inputs" )
208+ try :
209+ # Try to create proper fake vision inputs for the VL model
210+ fake_kwargs = _create_fake_vl_inputs (model , fake_input )
211+ except Exception as e :
212+ print (f"Failed to create fake VL inputs: { e } " )
213+ print ("Skipping requantize_resmooth_fused_llm_layers for VL model" )
214+ for handle in handles :
215+ handle .remove ()
216+ return
145217
146218 # Run forward pass so that all modules sharing the same input are collected using forward hook.
147219
148220 with set_quantizer_by_cfg_context (model , {"*" : {"enable" : False }}):
149221 if getattr (model .config , "is_encoder_decoder" , False ):
150222 # For encoder-decoder models, we need to pass both the encoder and decoder input ids
151223 model (fake_input , decoder_input_ids = decoder_fake_input )
224+ elif is_vl_model :
225+ # For VL models, use the fake vision inputs
226+ model (** fake_kwargs )
152227 else :
153228 model (fake_input )
154229
0 commit comments