@@ -85,17 +85,26 @@ def _create_fake_vl_inputs(model, fake_input_ids):
8585 Returns:
8686 dict: Dictionary of fake inputs for the VL model
8787 """
88+ import inspect
89+
8890 device = fake_input_ids .device
8991 batch_size = fake_input_ids .shape [0 ]
9092
93+ # Get the model's forward method signature to see what parameters it accepts
94+ forward_signature = inspect .signature (model .forward )
95+ accepted_params = set (forward_signature .parameters .keys ())
96+
9197 # Create fake inputs based on common VL model patterns
92- fake_inputs = {
93- "input_ids" : fake_input_ids ,
94- "attention_mask" : torch .ones_like (fake_input_ids ),
95- }
98+ fake_inputs = {}
99+
100+ # Always include basic text inputs if accepted
101+ if "input_ids" in accepted_params :
102+ fake_inputs ["input_ids" ] = fake_input_ids
103+ if "attention_mask" in accepted_params :
104+ fake_inputs ["attention_mask" ] = torch .ones_like (fake_input_ids )
96105
97- # Add vision-specific inputs based on model configuration
98- if hasattr (model .config , "vision_config" ):
106+ # Add vision-specific inputs based on model configuration and accepted parameters
107+ if hasattr (model .config , "vision_config" ) and "pixel_values" in accepted_params :
99108 vision_config = model .config .vision_config
100109 # Create fake pixel values based on vision config
101110 if hasattr (vision_config , "image_size" ):
@@ -113,16 +122,34 @@ def _create_fake_vl_inputs(model, fake_input_ids):
113122 [batch_size , num_channels , image_size , image_size ], dtype = torch .float32 , device = device
114123 )
115124
116- # Handle Nemotron-specific inputs
125+ # Handle Nemotron-specific inputs based on testing results
117126 model_name = getattr (model , "name_or_path" , "" ).lower ()
118127 if "nemotron" in model_name :
119- # Nemotron models may need specific image flags
120- fake_inputs ["image_flags" ] = torch .zeros ([batch_size , 1 ], dtype = torch .long , device = device )
128+ if "pixel_values" in accepted_params :
129+ # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512]
130+ # This represents 14 image patches, each 512x512 pixels with 3 channels
131+ num_patches = 14
132+ patch_size = 512
133+ num_channels = 3
134+
135+ # Override any previous pixel_values with the correct Nemotron format
136+ # Use small random values instead of zeros to avoid NoneType issues
137+ fake_inputs ["pixel_values" ] = (
138+ torch .randn (
139+ [num_patches , num_channels , patch_size , patch_size ],
140+ dtype = torch .float32 ,
141+ device = device ,
142+ )
143+ * 0.1
144+ ) # Small values to avoid extreme activations
121145
122- # Some VL models need aspect ratio information
123- fake_inputs ["aspect_ratio_ids" ] = None
124- fake_inputs ["aspect_ratio_mask" ] = None
125- fake_inputs ["cross_attention_mask" ] = None
146+ if "image_flags" in accepted_params :
147+ # Based on testing, image_flags should have shape [14] (no batch dimension)
148+ # to match the [14, 256, 4096] tensor it's used to mask
149+ num_patches = 14 # From pixel_values shape [14, 3, 512, 512]
150+ fake_inputs ["image_flags" ] = torch .zeros (
151+ [num_patches ], dtype = torch .long , device = device
152+ ) # Shape [14] to match vision tensor dimensions
126153
127154 return fake_inputs
128155
@@ -205,6 +232,31 @@ def _output_hook(module, input, output):
205232 elif is_vl_model :
206233 # For VL models, create proper fake vision inputs
207234 print ("Detected VL model during export - creating fake vision inputs" )
235+
236+ # Pre-emptively initialize distributed for Nemotron models that require it
237+ model_name = getattr (model , "name_or_path" , "" ).lower ()
238+ if "nemotron" in model_name :
239+ import os
240+
241+ import torch .distributed as dist
242+
243+ if not dist .is_available () or not dist .is_initialized ():
244+ print ("Pre-initializing distributed processing for Nemotron VL model" )
245+ # Set up minimal distributed environment
246+ os .environ .setdefault ("MASTER_ADDR" , "127.0.0.1" )
247+ os .environ .setdefault ("MASTER_PORT" , "29500" )
248+ os .environ .setdefault ("RANK" , "0" )
249+ os .environ .setdefault ("WORLD_SIZE" , "1" )
250+
251+ if dist .is_available () and not dist .is_initialized ():
252+ try :
253+ dist .init_process_group (
254+ backend = "nccl" if torch .cuda .is_available () else "gloo" ,
255+ rank = 0 ,
256+ world_size = 1 ,
257+ )
258+ except Exception as dist_e :
259+ print (f"Failed to initialize distributed processing: { dist_e } " )
208260 try :
209261 # Try to create proper fake vision inputs for the VL model
210262 fake_kwargs = _create_fake_vl_inputs (model , fake_input )
@@ -222,8 +274,47 @@ def _output_hook(module, input, output):
222274 # For encoder-decoder models, we need to pass both the encoder and decoder input ids
223275 model (fake_input , decoder_input_ids = decoder_fake_input )
224276 elif is_vl_model :
225- # For VL models, use the fake vision inputs
226- model (** fake_kwargs )
277+ # For VL models, try to run optimization on just the language model part
278+ language_model = None
279+ if hasattr (model , "language_model" ):
280+ language_model = model .language_model
281+ print (
282+ "Found language_model attribute - running optimization on language model only"
283+ )
284+ elif hasattr (model , "model" ) and hasattr (model .model , "language_model" ):
285+ language_model = model .model .language_model
286+ print (
287+ "Found language_model in model.model - running optimization on language model only"
288+ )
289+
290+ if language_model is not None :
291+ # Run optimization on just the language model with the same input format as regular LLMs
292+ # Use the same fake_input tensor that regular LLMs use
293+ print (
294+ f"Running optimization on language model with fake_input shape: { fake_input .shape } "
295+ )
296+ try :
297+ language_model (fake_input )
298+ print ("✅ Language model optimization completed successfully" )
299+ except Exception as e :
300+ print (f"Language model optimization failed: { e } " )
301+ print ("Continuing with export..." )
302+ else :
303+ # Fallback: try full model with VL inputs
304+ print ("No separate language_model found - trying full VL model" )
305+ try :
306+ model (** fake_kwargs )
307+ print ("✅ Full VL model optimization completed successfully" )
308+ except (ValueError , RuntimeError , AttributeError ) as e :
309+ if (
310+ "Default process group has not been initialized" in str (e )
311+ or "must match the size of tensor" in str (e )
312+ or "'bool' object has no attribute 'sum'" in str (e )
313+ ):
314+ print (f"VL model forward pass failed: { e } " )
315+ print ("Skipping optimization for VL model - continuing with export" )
316+ else :
317+ raise
227318 else :
228319 model (fake_input )
229320
0 commit comments