Skip to content

Commit 54ec55a

Browse files
committed
Create fake VL inputs during export for the Nemotron VL model
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 54cb469 commit 54ec55a

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,58 @@
7575
SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
7676

7777

78+
def _create_fake_vl_inputs(model, fake_input_ids):
79+
"""Create fake vision-language model inputs for export process.
80+
81+
Args:
82+
model: The VL model
83+
fake_input_ids: The fake text input IDs tensor
84+
85+
Returns:
86+
dict: Dictionary of fake inputs for the VL model
87+
"""
88+
device = fake_input_ids.device
89+
batch_size = fake_input_ids.shape[0]
90+
91+
# Create fake inputs based on common VL model patterns
92+
fake_inputs = {
93+
"input_ids": fake_input_ids,
94+
"attention_mask": torch.ones_like(fake_input_ids),
95+
}
96+
97+
# Add vision-specific inputs based on model configuration
98+
if hasattr(model.config, "vision_config"):
99+
vision_config = model.config.vision_config
100+
# Create fake pixel values based on vision config
101+
if hasattr(vision_config, "image_size"):
102+
image_size = vision_config.image_size
103+
else:
104+
image_size = 224 # Default size
105+
106+
if hasattr(vision_config, "num_channels"):
107+
num_channels = vision_config.num_channels
108+
else:
109+
num_channels = 3 # RGB default
110+
111+
# Create fake pixel values
112+
fake_inputs["pixel_values"] = torch.zeros(
113+
[batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
114+
)
115+
116+
# Handle Nemotron-specific inputs
117+
model_name = getattr(model, "name_or_path", "").lower()
118+
if "nemotron" in model_name:
119+
# Nemotron models may need specific image flags
120+
fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device)
121+
122+
# Some VL models need aspect ratio information
123+
fake_inputs["aspect_ratio_ids"] = None
124+
fake_inputs["aspect_ratio_mask"] = None
125+
fake_inputs["cross_attention_mask"] = None
126+
127+
return fake_inputs
128+
129+
78130
def _is_enabled_quantizer(quantizer):
79131
if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
80132
return True
@@ -134,6 +186,14 @@ def _output_hook(module, input, output):
134186
with torch.no_grad():
135187
fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
136188
decoder_fake_input = fake_input
189+
190+
# Check if this is a VL model that needs special input handling
191+
is_vl_model = (
192+
hasattr(model.config, "vision_config")
193+
or hasattr(model, "vision_model")
194+
or "nemotron" in getattr(model, "name_or_path", "").lower()
195+
)
196+
137197
if model_type.startswith("whisper"):
138198
# For Whisper models, we need to pass a fake input with the specific sequence length
139199
from transformers import AutoFeatureExtractor
@@ -142,13 +202,28 @@ def _output_hook(module, input, output):
142202
fake_input = torch.ones(
143203
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
144204
).to(model.device)
205+
elif is_vl_model:
206+
# For VL models, create proper fake vision inputs
207+
print("Detected VL model during export - creating fake vision inputs")
208+
try:
209+
# Try to create proper fake vision inputs for the VL model
210+
fake_kwargs = _create_fake_vl_inputs(model, fake_input)
211+
except Exception as e:
212+
print(f"Failed to create fake VL inputs: {e}")
213+
print("Skipping requantize_resmooth_fused_llm_layers for VL model")
214+
for handle in handles:
215+
handle.remove()
216+
return
145217

146218
# Run forward pass so that all modules sharing the same input are collected using forward hook.
147219

148220
with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
149221
if getattr(model.config, "is_encoder_decoder", False):
150222
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
151223
model(fake_input, decoder_input_ids=decoder_fake_input)
224+
elif is_vl_model:
225+
# For VL models, use the fake vision inputs
226+
model(**fake_kwargs)
152227
else:
153228
model(fake_input)
154229

0 commit comments

Comments
 (0)