Commit f4134e3

remove distributed processing setup and vision input generation since we only process the language model part during export
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent b90010c commit f4134e3

File tree: 1 file changed (+4, −130 lines)
modelopt/torch/export/unified_export_hf.py

Lines changed: 4 additions & 130 deletions
```diff
@@ -75,85 +75,6 @@
 SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
 
 
-def _create_fake_vl_inputs(model, fake_input_ids):
-    """Create fake vision-language model inputs for export process.
-
-    Args:
-        model: The VL model
-        fake_input_ids: The fake text input IDs tensor
-
-    Returns:
-        dict: Dictionary of fake inputs for the VL model
-    """
-    import inspect
-
-    device = fake_input_ids.device
-    batch_size = fake_input_ids.shape[0]
-
-    # Get the model's forward method signature to see what parameters it accepts
-    forward_signature = inspect.signature(model.forward)
-    accepted_params = set(forward_signature.parameters.keys())
-
-    # Create fake inputs based on common VL model patterns
-    fake_inputs = {}
-
-    # Always include basic text inputs if accepted
-    if "input_ids" in accepted_params:
-        fake_inputs["input_ids"] = fake_input_ids
-    if "attention_mask" in accepted_params:
-        fake_inputs["attention_mask"] = torch.ones_like(fake_input_ids)
-
-    # Add vision-specific inputs based on model configuration and accepted parameters
-    if hasattr(model.config, "vision_config") and "pixel_values" in accepted_params:
-        vision_config = model.config.vision_config
-        # Create fake pixel values based on vision config
-        if hasattr(vision_config, "image_size"):
-            image_size = vision_config.image_size
-        else:
-            image_size = 224  # Default size
-
-        if hasattr(vision_config, "num_channels"):
-            num_channels = vision_config.num_channels
-        else:
-            num_channels = 3  # RGB default
-
-        # Create fake pixel values
-        fake_inputs["pixel_values"] = torch.zeros(
-            [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device
-        )
-
-    # Handle Nemotron-specific inputs based on testing results
-    model_name = getattr(model, "name_or_path", "").lower()
-    if "nemotron" in model_name:
-        if "pixel_values" in accepted_params:
-            # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512]
-            # This represents 14 image patches, each 512x512 pixels with 3 channels
-            num_patches = 14
-            patch_size = 512
-            num_channels = 3
-
-            # Override any previous pixel_values with the correct Nemotron format
-            # Use small random values instead of zeros to avoid NoneType issues
-            fake_inputs["pixel_values"] = (
-                torch.randn(
-                    [num_patches, num_channels, patch_size, patch_size],
-                    dtype=torch.float32,
-                    device=device,
-                )
-                * 0.1
-            )  # Small values to avoid extreme activations
-
-        if "image_flags" in accepted_params:
-            # Based on testing, image_flags should have shape [14] (no batch dimension)
-            # to match the [14, 256, 4096] tensor it's used to mask
-            num_patches = 14  # From pixel_values shape [14, 3, 512, 512]
-            fake_inputs["image_flags"] = torch.zeros(
-                [num_patches], dtype=torch.long, device=device
-            )  # Shape [14] to match vision tensor dimensions
-
-    return fake_inputs
-
-
 def _is_enabled_quantizer(quantizer):
     if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
         return True
@@ -230,42 +151,8 @@ def _output_hook(module, input, output):
                 [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
             ).to(model.device)
         elif is_vl_model:
-            # For VL models, create proper fake vision inputs
-            print("Detected VL model during export - creating fake vision inputs")
-
-            # Pre-emptively initialize distributed for Nemotron models that require it
-            model_name = getattr(model, "name_or_path", "").lower()
-            if "nemotron" in model_name:
-                import os
-
-                import torch.distributed as dist
-
-                if not dist.is_available() or not dist.is_initialized():
-                    print("Pre-initializing distributed processing for Nemotron VL model")
-                    # Set up minimal distributed environment
-                    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
-                    os.environ.setdefault("MASTER_PORT", "29500")
-                    os.environ.setdefault("RANK", "0")
-                    os.environ.setdefault("WORLD_SIZE", "1")
-
-                    if dist.is_available() and not dist.is_initialized():
-                        try:
-                            dist.init_process_group(
-                                backend="nccl" if torch.cuda.is_available() else "gloo",
-                                rank=0,
-                                world_size=1,
-                            )
-                        except Exception as dist_e:
-                            print(f"Failed to initialize distributed processing: {dist_e}")
-            try:
-                # Try to create proper fake vision inputs for the VL model
-                fake_kwargs = _create_fake_vl_inputs(model, fake_input)
-            except Exception as e:
-                print(f"Failed to create fake VL inputs: {e}")
-                print("Skipping requantize_resmooth_fused_llm_layers for VL model")
-                for handle in handles:
-                    handle.remove()
-                return
+            # For VL models, run optimization on language model component only
+            print("Detected VL model during export - optimizing language model component")
 
             # Run forward pass so that all modules sharing the same input are collected using forward hook.
 
@@ -300,21 +187,8 @@ def _output_hook(module, input, output):
                     print(f"Language model optimization failed: {e}")
                     print("Continuing with export...")
             else:
-                # Fallback: try full model with VL inputs
-                print("No separate language_model found - trying full VL model")
-                try:
-                    model(**fake_kwargs)
-                    print("✅ Full VL model optimization completed successfully")
-                except (ValueError, RuntimeError, AttributeError) as e:
-                    if (
-                        "Default process group has not been initialized" in str(e)
-                        or "must match the size of tensor" in str(e)
-                        or "'bool' object has no attribute 'sum'" in str(e)
-                    ):
-                        print(f"VL model forward pass failed: {e}")
-                        print("Skipping optimization for VL model - continuing with export")
-                    else:
-                        raise
+                print("Warning: No language_model found in VL model - skipping optimization")
+                print("This is unexpected for most VL models")
         else:
            model(fake_input)
 
```
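As a reading aid, here is a minimal sketch of what the simplified VL path amounts to after this change. It is not taken from the file: `_run_language_model_only` is a hypothetical helper, and the assumption that the language-model component accepts plain token IDs may not match how the real export code builds its inputs; only the printed messages and the language-model-only intent come from the diff above.

```python
import torch


def _run_language_model_only(model, fake_input_ids):
    """Hypothetical sketch of the new VL-model handling: run the export-time
    calibration forward pass on the language-model component only, with no fake
    vision inputs and no single-rank distributed setup."""
    language_model = getattr(model, "language_model", None)
    if language_model is None:
        # Mirrors the new warning path in the diff when no language_model exists.
        print("Warning: No language_model found in VL model - skipping optimization")
        return

    try:
        with torch.no_grad():
            # Assumption: the component accepts token IDs directly; the actual code
            # between the shown hunks may prepare its inputs differently.
            language_model(fake_input_ids)
    except Exception as e:
        # Mirrors the diff's behavior of logging the failure and continuing with export.
        print(f"Language model optimization failed: {e}")
        print("Continuing with export...")
```

The point of the commit is that this text-only forward pass is sufficient for the forward hooks used during export, so the fake `pixel_values`/`image_flags` generation and the single-rank `init_process_group` setup become unnecessary.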
