Commit b90010c

update fake inputs generation, initialize distributed for Nemotron models

Signed-off-by: Zhiyu Cheng <[email protected]>

Parent: 54ec55a
1 file changed: modelopt/torch/export/unified_export_hf.py (+106 -15)
```diff
@@ -85,17 +85,26 @@ def _create_fake_vl_inputs(model, fake_input_ids):
     Returns:
         dict: Dictionary of fake inputs for the VL model
     """
+    import inspect
+
     device = fake_input_ids.device
     batch_size = fake_input_ids.shape[0]
 
+    # Get the model's forward method signature to see what parameters it accepts
+    forward_signature = inspect.signature(model.forward)
+    accepted_params = set(forward_signature.parameters.keys())
+
     # Create fake inputs based on common VL model patterns
-    fake_inputs = {
-        "input_ids": fake_input_ids,
-        "attention_mask": torch.ones_like(fake_input_ids),
-    }
+    fake_inputs = {}
+
+    # Always include basic text inputs if accepted
+    if "input_ids" in accepted_params:
+        fake_inputs["input_ids"] = fake_input_ids
+    if "attention_mask" in accepted_params:
+        fake_inputs["attention_mask"] = torch.ones_like(fake_input_ids)
 
-    # Add vision-specific inputs based on model configuration
-    if hasattr(model.config, "vision_config"):
+    # Add vision-specific inputs based on model configuration and accepted parameters
+    if hasattr(model.config, "vision_config") and "pixel_values" in accepted_params:
         vision_config = model.config.vision_config
         # Create fake pixel values based on vision config
         if hasattr(vision_config, "image_size"):
```
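The main change in this hunk is that fake inputs are now filtered against the model's forward() signature instead of being passed unconditionally. A minimal standalone sketch of that pattern (filter_kwargs_to_signature and toy_forward are illustrative names, not part of the patch):

```python
import inspect

import torch


def filter_kwargs_to_signature(forward_fn, candidate_kwargs):
    """Keep only the kwargs that forward_fn actually declares as parameters."""
    accepted = set(inspect.signature(forward_fn).parameters.keys())
    return {k: v for k, v in candidate_kwargs.items() if k in accepted}


# Toy forward that accepts only text inputs; pixel_values gets dropped.
def toy_forward(input_ids, attention_mask=None):
    return input_ids


candidates = {
    "input_ids": torch.zeros([1, 4], dtype=torch.long),
    "attention_mask": torch.ones([1, 4], dtype=torch.long),
    "pixel_values": torch.zeros([1, 3, 224, 224]),
}
print(sorted(filter_kwargs_to_signature(toy_forward, candidates)))
# ['attention_mask', 'input_ids']
```

Note that a forward() declared with **kwargs would expose only a generic "kwargs" parameter, so this check assumes the model lists its inputs as explicit parameter names.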
```diff
@@ -113,16 +122,34 @@ def _create_fake_vl_inputs(model, fake_input_ids):
                 [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
             )
 
-    # Handle Nemotron-specific inputs
+    # Handle Nemotron-specific inputs based on testing results
     model_name = getattr(model, "name_or_path", "").lower()
     if "nemotron" in model_name:
-        # Nemotron models may need specific image flags
-        fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device)
+        if "pixel_values" in accepted_params:
+            # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512]
+            # This represents 14 image patches, each 512x512 pixels with 3 channels
+            num_patches = 14
+            patch_size = 512
+            num_channels = 3
+
+            # Override any previous pixel_values with the correct Nemotron format
+            # Use small random values instead of zeros to avoid NoneType issues
+            fake_inputs["pixel_values"] = (
+                torch.randn(
+                    [num_patches, num_channels, patch_size, patch_size],
+                    dtype=torch.float32,
+                    device=device,
+                )
+                * 0.1
+            )  # Small values to avoid extreme activations
 
-        # Some VL models need aspect ratio information
-        fake_inputs["aspect_ratio_ids"] = None
-        fake_inputs["aspect_ratio_mask"] = None
-        fake_inputs["cross_attention_mask"] = None
+        if "image_flags" in accepted_params:
+            # Based on testing, image_flags should have shape [14] (no batch dimension)
+            # to match the [14, 256, 4096] tensor it's used to mask
+            num_patches = 14  # From pixel_values shape [14, 3, 512, 512]
+            fake_inputs["image_flags"] = torch.zeros(
+                [num_patches], dtype=torch.long, device=device
+            )  # Shape [14] to match vision tensor dimensions
 
     return fake_inputs
```
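The Nemotron shapes above can be sanity-checked in isolation. A small sketch mirroring the values hard-coded in this hunk, with a stand-in for the [14, 256, 4096] vision tensor to illustrate why image_flags carries no batch dimension (the vision_tokens stub is an assumption for illustration, not the model's real tensor):

```python
import torch

# Shapes hard-coded by the patch for Nemotron VL fake inputs.
num_patches, num_channels, patch_size = 14, 3, 512

pixel_values = (
    torch.randn([num_patches, num_channels, patch_size, patch_size], dtype=torch.float32) * 0.1
)
image_flags = torch.zeros([num_patches], dtype=torch.long)

# Stand-in for the [14, 256, 4096] per-patch vision tensor the flags are matched against;
# a [14] boolean mask indexes its leading dimension, while a [14, 1] mask would not line up.
vision_tokens = torch.zeros([num_patches, 256, 4096])
selected = vision_tokens[image_flags.bool()]

print(pixel_values.shape)  # torch.Size([14, 3, 512, 512])
print(image_flags.shape)   # torch.Size([14])
print(selected.shape)      # torch.Size([0, 256, 4096]) with all-zero flags
```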

```diff
@@ -205,6 +232,31 @@ def _output_hook(module, input, output):
         elif is_vl_model:
             # For VL models, create proper fake vision inputs
             print("Detected VL model during export - creating fake vision inputs")
+
+            # Pre-emptively initialize distributed for Nemotron models that require it
+            model_name = getattr(model, "name_or_path", "").lower()
+            if "nemotron" in model_name:
+                import os
+
+                import torch.distributed as dist
+
+                if not dist.is_available() or not dist.is_initialized():
+                    print("Pre-initializing distributed processing for Nemotron VL model")
+                    # Set up minimal distributed environment
+                    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+                    os.environ.setdefault("MASTER_PORT", "29500")
+                    os.environ.setdefault("RANK", "0")
+                    os.environ.setdefault("WORLD_SIZE", "1")
+
+                    if dist.is_available() and not dist.is_initialized():
+                        try:
+                            dist.init_process_group(
+                                backend="nccl" if torch.cuda.is_available() else "gloo",
+                                rank=0,
+                                world_size=1,
+                            )
+                        except Exception as dist_e:
+                            print(f"Failed to initialize distributed processing: {dist_e}")
             try:
                 # Try to create proper fake vision inputs for the VL model
                 fake_kwargs = _create_fake_vl_inputs(model, fake_input)
```
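The single-process distributed setup added here can be reproduced on its own. A minimal sketch using the same environment defaults and backend choice as the patch; the teardown at the end is an addition for standalone use:

```python
import os

import torch
import torch.distributed as dist

# Single-process "fake" distributed setup, matching the defaults used in the patch.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if dist.is_available() and not dist.is_initialized():
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        rank=0,
        world_size=1,
    )
    print(dist.get_rank(), dist.get_world_size())  # 0 1

    # Tear down afterwards so later code can initialize its own process group.
    dist.destroy_process_group()
```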
```diff
@@ -222,8 +274,47 @@ def _output_hook(module, input, output):
                 # For encoder-decoder models, we need to pass both the encoder and decoder input ids
                 model(fake_input, decoder_input_ids=decoder_fake_input)
             elif is_vl_model:
-                # For VL models, use the fake vision inputs
-                model(**fake_kwargs)
+                # For VL models, try to run optimization on just the language model part
+                language_model = None
+                if hasattr(model, "language_model"):
+                    language_model = model.language_model
+                    print(
+                        "Found language_model attribute - running optimization on language model only"
+                    )
+                elif hasattr(model, "model") and hasattr(model.model, "language_model"):
+                    language_model = model.model.language_model
+                    print(
+                        "Found language_model in model.model - running optimization on language model only"
+                    )
+
+                if language_model is not None:
+                    # Run optimization on just the language model with the same input format as regular LLMs
+                    # Use the same fake_input tensor that regular LLMs use
+                    print(
+                        f"Running optimization on language model with fake_input shape: {fake_input.shape}"
+                    )
+                    try:
+                        language_model(fake_input)
+                        print("✅ Language model optimization completed successfully")
+                    except Exception as e:
+                        print(f"Language model optimization failed: {e}")
+                        print("Continuing with export...")
+                else:
+                    # Fallback: try full model with VL inputs
+                    print("No separate language_model found - trying full VL model")
+                    try:
+                        model(**fake_kwargs)
+                        print("✅ Full VL model optimization completed successfully")
+                    except (ValueError, RuntimeError, AttributeError) as e:
+                        if (
+                            "Default process group has not been initialized" in str(e)
+                            or "must match the size of tensor" in str(e)
+                            or "'bool' object has no attribute 'sum'" in str(e)
+                        ):
+                            print(f"VL model forward pass failed: {e}")
+                            print("Skipping optimization for VL model - continuing with export")
+                        else:
+                            raise
             else:
                 model(fake_input)
```
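The language-model lookup used above can be factored into a small helper. A sketch assuming only that the VL model exposes its text backbone as language_model, either directly or under model.model (find_language_model, vl_model, and fake_input are illustrative names, not part of the patch):

```python
def find_language_model(model):
    """Return the text backbone of a VL model, using the same fallback order as the patch."""
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "model") and hasattr(model.model, "language_model"):
        return model.model.language_model
    return None


# Usage sketch: run the export-time forward pass on the text backbone only,
# feeding it the same fake_input tensor a plain LLM would receive.
# language_model = find_language_model(vl_model)
# if language_model is not None:
#     with torch.no_grad():
#         language_model(fake_input)
```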