
Commit f40501d

update fake inputs generation, initialize distributed for Nemotron models

Signed-off-by: Zhiyu Cheng <[email protected]>

1 parent c71b661

1 file changed
modelopt/torch/export/unified_export_hf.py

Lines changed: 106 additions & 15 deletions
@@ -83,17 +83,26 @@ def _create_fake_vl_inputs(model, fake_input_ids):
     Returns:
         dict: Dictionary of fake inputs for the VL model
     """
+    import inspect
+
     device = fake_input_ids.device
     batch_size = fake_input_ids.shape[0]
 
+    # Get the model's forward method signature to see what parameters it accepts
+    forward_signature = inspect.signature(model.forward)
+    accepted_params = set(forward_signature.parameters.keys())
+
     # Create fake inputs based on common VL model patterns
-    fake_inputs = {
-        "input_ids": fake_input_ids,
-        "attention_mask": torch.ones_like(fake_input_ids),
-    }
+    fake_inputs = {}
+
+    # Always include basic text inputs if accepted
+    if "input_ids" in accepted_params:
+        fake_inputs["input_ids"] = fake_input_ids
+    if "attention_mask" in accepted_params:
+        fake_inputs["attention_mask"] = torch.ones_like(fake_input_ids)
 
-    # Add vision-specific inputs based on model configuration
-    if hasattr(model.config, "vision_config"):
+    # Add vision-specific inputs based on model configuration and accepted parameters
+    if hasattr(model.config, "vision_config") and "pixel_values" in accepted_params:
         vision_config = model.config.vision_config
         # Create fake pixel values based on vision config
         if hasattr(vision_config, "image_size"):
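The signature-probing step in the hunk above can be exercised on its own. A minimal sketch (ToyVLModel is a made-up class for illustration, not part of this diff), showing that only the kwargs a forward method actually declares get forwarded:

    import inspect

    import torch


    class ToyVLModel(torch.nn.Module):
        # Hypothetical model whose forward accepts only input_ids and pixel_values.
        def forward(self, input_ids, pixel_values=None):
            return input_ids


    model = ToyVLModel()
    accepted = set(inspect.signature(model.forward).parameters.keys())
    candidates = {
        "input_ids": torch.zeros([1, 4], dtype=torch.long),
        "attention_mask": torch.ones([1, 4], dtype=torch.long),
        "pixel_values": torch.zeros([1, 3, 224, 224]),
    }
    # Drop kwargs the forward signature does not declare (attention_mask here),
    # so the call below cannot fail with an unexpected-keyword error.
    fake_inputs = {k: v for k, v in candidates.items() if k in accepted}
    model(**fake_inputs)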
@@ -111,16 +120,34 @@ def _create_fake_vl_inputs(model, fake_input_ids):
             [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
         )
 
-    # Handle Nemotron-specific inputs
+    # Handle Nemotron-specific inputs based on testing results
     model_name = getattr(model, "name_or_path", "").lower()
     if "nemotron" in model_name:
-        # Nemotron models may need specific image flags
-        fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device)
+        if "pixel_values" in accepted_params:
+            # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512]
+            # This represents 14 image patches, each 512x512 pixels with 3 channels
+            num_patches = 14
+            patch_size = 512
+            num_channels = 3
+
+            # Override any previous pixel_values with the correct Nemotron format
+            # Use small random values instead of zeros to avoid NoneType issues
+            fake_inputs["pixel_values"] = (
+                torch.randn(
+                    [num_patches, num_channels, patch_size, patch_size],
+                    dtype=torch.float32,
+                    device=device,
+                )
+                * 0.1
+            )  # Small values to avoid extreme activations
 
-        # Some VL models need aspect ratio information
-        fake_inputs["aspect_ratio_ids"] = None
-        fake_inputs["aspect_ratio_mask"] = None
-        fake_inputs["cross_attention_mask"] = None
+        if "image_flags" in accepted_params:
+            # Based on testing, image_flags should have shape [14] (no batch dimension)
+            # to match the [14, 256, 4096] tensor it's used to mask
+            num_patches = 14  # From pixel_values shape [14, 3, 512, 512]
+            fake_inputs["image_flags"] = torch.zeros(
+                [num_patches], dtype=torch.long, device=device
+            )  # Shape [14] to match vision tensor dimensions
 
     return fake_inputs
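As a standalone shape check of what the Nemotron branch above produces (the tensor values here are illustrative; only the shapes come from the diff):

    import torch

    num_patches, num_channels, patch_size = 14, 3, 512
    pixel_values = torch.randn([num_patches, num_channels, patch_size, patch_size]) * 0.1
    image_flags = torch.zeros([num_patches], dtype=torch.long)

    assert pixel_values.shape == (14, 3, 512, 512)  # 14 patches of 3x512x512
    assert image_flags.shape == (14,)  # no batch dimension, one flag per patch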

@@ -202,6 +229,31 @@ def _output_hook(module, input, output):
         elif is_vl_model:
             # For VL models, create proper fake vision inputs
             print("Detected VL model during export - creating fake vision inputs")
+
+            # Pre-emptively initialize distributed for Nemotron models that require it
+            model_name = getattr(model, "name_or_path", "").lower()
+            if "nemotron" in model_name:
+                import os
+
+                import torch.distributed as dist
+
+                if not dist.is_available() or not dist.is_initialized():
+                    print("Pre-initializing distributed processing for Nemotron VL model")
+                    # Set up minimal distributed environment
+                    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+                    os.environ.setdefault("MASTER_PORT", "29500")
+                    os.environ.setdefault("RANK", "0")
+                    os.environ.setdefault("WORLD_SIZE", "1")
+
+                    if dist.is_available() and not dist.is_initialized():
+                        try:
+                            dist.init_process_group(
+                                backend="nccl" if torch.cuda.is_available() else "gloo",
+                                rank=0,
+                                world_size=1,
+                            )
+                        except Exception as dist_e:
+                            print(f"Failed to initialize distributed processing: {dist_e}")
             try:
                 # Try to create proper fake vision inputs for the VL model
                 fake_kwargs = _create_fake_vl_inputs(model, fake_input)
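The same single-process bootstrap can be reproduced outside the exporter. A minimal sketch, assuming a CPU-only environment (gloo backend) and that no process group exists yet:

    import os

    import torch.distributed as dist

    # Point the default "env://" rendezvous at this single process.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    # ... code that assumes a default process group runs here ...

    dist.destroy_process_group()  # clean up the group when done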
@@ -219,8 +271,47 @@ def _output_hook(module, input, output):
                 # For encoder-decoder models, we need to pass both the encoder and decoder input ids
                 model(fake_input, decoder_input_ids=decoder_fake_input)
             elif is_vl_model:
-                # For VL models, use the fake vision inputs
-                model(**fake_kwargs)
+                # For VL models, try to run optimization on just the language model part
+                language_model = None
+                if hasattr(model, "language_model"):
+                    language_model = model.language_model
+                    print(
+                        "Found language_model attribute - running optimization on language model only"
+                    )
+                elif hasattr(model, "model") and hasattr(model.model, "language_model"):
+                    language_model = model.model.language_model
+                    print(
+                        "Found language_model in model.model - running optimization on language model only"
+                    )
+
+                if language_model is not None:
+                    # Run optimization on just the language model with the same input format as regular LLMs
+                    # Use the same fake_input tensor that regular LLMs use
+                    print(
+                        f"Running optimization on language model with fake_input shape: {fake_input.shape}"
+                    )
+                    try:
+                        language_model(fake_input)
+                        print("✅ Language model optimization completed successfully")
+                    except Exception as e:
+                        print(f"Language model optimization failed: {e}")
+                        print("Continuing with export...")
+                else:
+                    # Fallback: try full model with VL inputs
+                    print("No separate language_model found - trying full VL model")
+                    try:
+                        model(**fake_kwargs)
+                        print("✅ Full VL model optimization completed successfully")
+                    except (ValueError, RuntimeError, AttributeError) as e:
+                        if (
+                            "Default process group has not been initialized" in str(e)
+                            or "must match the size of tensor" in str(e)
+                            or "'bool' object has no attribute 'sum'" in str(e)
+                        ):
+                            print(f"VL model forward pass failed: {e}")
+                            print("Skipping optimization for VL model - continuing with export")
+                        else:
+                            raise
             else:
                 model(fake_input)
 
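The language-model lookup in the last hunk generalizes to a small helper. In this sketch, find_language_model is a name invented here (not part of the diff); it mirrors the lookup order the code checks:

    def find_language_model(model):
        # Check the two wrapper layouts the hunk above handles:
        # first model.language_model, then model.model.language_model.
        if hasattr(model, "language_model"):
            return model.language_model
        if hasattr(model, "model") and hasattr(model.model, "language_model"):
            return model.model.language_model
        return None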
