Skip to content

Commit c71b661

Browse files
committed
Create fake VL inputs in export for Nemotron VL model
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 32bdfa9 commit c71b661

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,58 @@
7373
SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
7474

7575

76+
def _create_fake_vl_inputs(model, fake_input_ids):
77+
"""Create fake vision-language model inputs for export process.
78+
79+
Args:
80+
model: The VL model
81+
fake_input_ids: The fake text input IDs tensor
82+
83+
Returns:
84+
dict: Dictionary of fake inputs for the VL model
85+
"""
86+
device = fake_input_ids.device
87+
batch_size = fake_input_ids.shape[0]
88+
89+
# Create fake inputs based on common VL model patterns
90+
fake_inputs = {
91+
"input_ids": fake_input_ids,
92+
"attention_mask": torch.ones_like(fake_input_ids),
93+
}
94+
95+
# Add vision-specific inputs based on model configuration
96+
if hasattr(model.config, "vision_config"):
97+
vision_config = model.config.vision_config
98+
# Create fake pixel values based on vision config
99+
if hasattr(vision_config, "image_size"):
100+
image_size = vision_config.image_size
101+
else:
102+
image_size = 224 # Default size
103+
104+
if hasattr(vision_config, "num_channels"):
105+
num_channels = vision_config.num_channels
106+
else:
107+
num_channels = 3 # RGB default
108+
109+
# Create fake pixel values
110+
fake_inputs["pixel_values"] = torch.zeros(
111+
[batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
112+
)
113+
114+
# Handle Nemotron-specific inputs
115+
model_name = getattr(model, "name_or_path", "").lower()
116+
if "nemotron" in model_name:
117+
# Nemotron models may need specific image flags
118+
fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device)
119+
120+
# Some VL models need aspect ratio information
121+
fake_inputs["aspect_ratio_ids"] = None
122+
fake_inputs["aspect_ratio_mask"] = None
123+
fake_inputs["cross_attention_mask"] = None
124+
125+
return fake_inputs
126+
127+
76128
def _is_enabled_quantizer(quantizer):
77129
if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
78130
return True
@@ -131,6 +183,14 @@ def _output_hook(module, input, output):
131183
with torch.no_grad():
132184
fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
133185
decoder_fake_input = fake_input
186+
187+
# Check if this is a VL model that needs special input handling
188+
is_vl_model = (
189+
hasattr(model.config, "vision_config")
190+
or hasattr(model, "vision_model")
191+
or "nemotron" in getattr(model, "name_or_path", "").lower()
192+
)
193+
134194
if model_type.startswith("whisper"):
135195
# For Whisper models, we need to pass a fake input with the specific sequence length
136196
from transformers import AutoFeatureExtractor
@@ -139,13 +199,28 @@ def _output_hook(module, input, output):
139199
fake_input = torch.ones(
140200
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
141201
).to(model.device)
202+
elif is_vl_model:
203+
# For VL models, create proper fake vision inputs
204+
print("Detected VL model during export - creating fake vision inputs")
205+
try:
206+
# Try to create proper fake vision inputs for the VL model
207+
fake_kwargs = _create_fake_vl_inputs(model, fake_input)
208+
except Exception as e:
209+
print(f"Failed to create fake VL inputs: {e}")
210+
print("Skipping requantize_resmooth_fused_llm_layers for VL model")
211+
for handle in handles:
212+
handle.remove()
213+
return
142214

143215
# Run forward pass so that all modules sharing the same input are collected using forward hook.
144216

145217
with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
146218
if getattr(model.config, "is_encoder_decoder", False):
147219
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
148220
model(fake_input, decoder_input_ids=decoder_fake_input)
221+
elif is_vl_model:
222+
# For VL models, use the fake vision inputs
223+
model(**fake_kwargs)
149224
else:
150225
model(fake_input)
151226

0 commit comments

Comments
 (0)