@@ -83,17 +83,26 @@ def _create_fake_vl_inputs(model, fake_input_ids):
83
83
Returns:
84
84
dict: Dictionary of fake inputs for the VL model
85
85
"""
86
+ import inspect
87
+
86
88
device = fake_input_ids .device
87
89
batch_size = fake_input_ids .shape [0 ]
88
90
91
+ # Get the model's forward method signature to see what parameters it accepts
92
+ forward_signature = inspect .signature (model .forward )
93
+ accepted_params = set (forward_signature .parameters .keys ())
94
+
89
95
# Create fake inputs based on common VL model patterns
90
- fake_inputs = {
91
- "input_ids" : fake_input_ids ,
92
- "attention_mask" : torch .ones_like (fake_input_ids ),
93
- }
96
+ fake_inputs = {}
97
+
98
+ # Always include basic text inputs if accepted
99
+ if "input_ids" in accepted_params :
100
+ fake_inputs ["input_ids" ] = fake_input_ids
101
+ if "attention_mask" in accepted_params :
102
+ fake_inputs ["attention_mask" ] = torch .ones_like (fake_input_ids )
94
103
95
- # Add vision-specific inputs based on model configuration
96
- if hasattr (model .config , "vision_config" ):
104
+ # Add vision-specific inputs based on model configuration and accepted parameters
105
+ if hasattr (model .config , "vision_config" ) and "pixel_values" in accepted_params :
97
106
vision_config = model .config .vision_config
98
107
# Create fake pixel values based on vision config
99
108
if hasattr (vision_config , "image_size" ):
@@ -111,16 +120,34 @@ def _create_fake_vl_inputs(model, fake_input_ids):
111
120
[batch_size , num_channels , image_size , image_size ], dtype = torch .float32 , device = device
112
121
)
113
122
114
- # Handle Nemotron-specific inputs
123
+ # Handle Nemotron-specific inputs based on testing results
115
124
model_name = getattr (model , "name_or_path" , "" ).lower ()
116
125
if "nemotron" in model_name :
117
- # Nemotron models may need specific image flags
118
- fake_inputs ["image_flags" ] = torch .zeros ([batch_size , 1 ], dtype = torch .long , device = device )
126
+ if "pixel_values" in accepted_params :
127
+ # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512]
128
+ # This represents 14 image patches, each 512x512 pixels with 3 channels
129
+ num_patches = 14
130
+ patch_size = 512
131
+ num_channels = 3
132
+
133
+ # Override any previous pixel_values with the correct Nemotron format
134
+ # Use small random values instead of zeros to avoid NoneType issues
135
+ fake_inputs ["pixel_values" ] = (
136
+ torch .randn (
137
+ [num_patches , num_channels , patch_size , patch_size ],
138
+ dtype = torch .float32 ,
139
+ device = device ,
140
+ )
141
+ * 0.1
142
+ ) # Small values to avoid extreme activations
119
143
120
- # Some VL models need aspect ratio information
121
- fake_inputs ["aspect_ratio_ids" ] = None
122
- fake_inputs ["aspect_ratio_mask" ] = None
123
- fake_inputs ["cross_attention_mask" ] = None
144
+ if "image_flags" in accepted_params :
145
+ # Based on testing, image_flags should have shape [14] (no batch dimension)
146
+ # to match the [14, 256, 4096] tensor it's used to mask
147
+ num_patches = 14 # From pixel_values shape [14, 3, 512, 512]
148
+ fake_inputs ["image_flags" ] = torch .zeros (
149
+ [num_patches ], dtype = torch .long , device = device
150
+ ) # Shape [14] to match vision tensor dimensions
124
151
125
152
return fake_inputs
126
153
@@ -202,6 +229,31 @@ def _output_hook(module, input, output):
202
229
elif is_vl_model :
203
230
# For VL models, create proper fake vision inputs
204
231
print ("Detected VL model during export - creating fake vision inputs" )
232
+
233
+ # Pre-emptively initialize distributed for Nemotron models that require it
234
+ model_name = getattr (model , "name_or_path" , "" ).lower ()
235
+ if "nemotron" in model_name :
236
+ import os
237
+
238
+ import torch .distributed as dist
239
+
240
+ if not dist .is_available () or not dist .is_initialized ():
241
+ print ("Pre-initializing distributed processing for Nemotron VL model" )
242
+ # Set up minimal distributed environment
243
+ os .environ .setdefault ("MASTER_ADDR" , "127.0.0.1" )
244
+ os .environ .setdefault ("MASTER_PORT" , "29500" )
245
+ os .environ .setdefault ("RANK" , "0" )
246
+ os .environ .setdefault ("WORLD_SIZE" , "1" )
247
+
248
+ if dist .is_available () and not dist .is_initialized ():
249
+ try :
250
+ dist .init_process_group (
251
+ backend = "nccl" if torch .cuda .is_available () else "gloo" ,
252
+ rank = 0 ,
253
+ world_size = 1 ,
254
+ )
255
+ except Exception as dist_e :
256
+ print (f"Failed to initialize distributed processing: { dist_e } " )
205
257
try :
206
258
# Try to create proper fake vision inputs for the VL model
207
259
fake_kwargs = _create_fake_vl_inputs (model , fake_input )
@@ -219,8 +271,47 @@ def _output_hook(module, input, output):
219
271
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
220
272
model (fake_input , decoder_input_ids = decoder_fake_input )
221
273
elif is_vl_model :
222
- # For VL models, use the fake vision inputs
223
- model (** fake_kwargs )
274
+ # For VL models, try to run optimization on just the language model part
275
+ language_model = None
276
+ if hasattr (model , "language_model" ):
277
+ language_model = model .language_model
278
+ print (
279
+ "Found language_model attribute - running optimization on language model only"
280
+ )
281
+ elif hasattr (model , "model" ) and hasattr (model .model , "language_model" ):
282
+ language_model = model .model .language_model
283
+ print (
284
+ "Found language_model in model.model - running optimization on language model only"
285
+ )
286
+
287
+ if language_model is not None :
288
+ # Run optimization on just the language model with the same input format as regular LLMs
289
+ # Use the same fake_input tensor that regular LLMs use
290
+ print (
291
+ f"Running optimization on language model with fake_input shape: { fake_input .shape } "
292
+ )
293
+ try :
294
+ language_model (fake_input )
295
+ print ("✅ Language model optimization completed successfully" )
296
+ except Exception as e :
297
+ print (f"Language model optimization failed: { e } " )
298
+ print ("Continuing with export..." )
299
+ else :
300
+ # Fallback: try full model with VL inputs
301
+ print ("No separate language_model found - trying full VL model" )
302
+ try :
303
+ model (** fake_kwargs )
304
+ print ("✅ Full VL model optimization completed successfully" )
305
+ except (ValueError , RuntimeError , AttributeError ) as e :
306
+ if (
307
+ "Default process group has not been initialized" in str (e )
308
+ or "must match the size of tensor" in str (e )
309
+ or "'bool' object has no attribute 'sum'" in str (e )
310
+ ):
311
+ print (f"VL model forward pass failed: { e } " )
312
+ print ("Skipping optimization for VL model - continuing with export" )
313
+ else :
314
+ raise
224
315
else :
225
316
model (fake_input )
226
317
0 commit comments