17 changes: 17 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -136,6 +136,17 @@ def get_model(
if device == "cpu":
device_map = "cpu"

# Special handling for vision-language models that may have device mapping issues
# Check if this is a VL model by looking at the model path
is_vl_model = any(
vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"]
)
if is_vl_model:
print(
"Detected vision-language model. Disabling automatic device mapping to avoid device_map errors."
)
device_map = None

config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
if attn_implementation is not None:
config_kwargs["attn_implementation"] = attn_implementation
@@ -235,6 +246,12 @@ def get_model(
**model_kwargs,
)
model.eval()

# If device_map was disabled (None), manually move model to target device
if device_map is None and device != "cpu":
print(f"Moving model to {device} device...")
model = model.to(device)

if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

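For reference, the two hunks above form a single device-handling path inside get_model: keyword-based VL detection turns off automatic device mapping, and the model is then moved to the target device manually after loading. The condensed sketch below only illustrates that flow; load_for_calibration is a hypothetical wrapper, not the actual get_model signature.

from transformers import AutoModelForCausalLM

def load_for_calibration(ckpt_path, device="cuda", trust_remote_code=False):
    # Keyword-based VL detection, mirroring the check added above.
    is_vl_model = any(
        kw in ckpt_path.lower() for kw in ["vl", "vision", "nemotron-nano-vl"]
    )

    # VL models skip automatic sharding; other models keep "auto" (or "cpu").
    device_map = None if is_vl_model else ("cpu" if device == "cpu" else "auto")

    model = AutoModelForCausalLM.from_pretrained(
        ckpt_path, device_map=device_map, trust_remote_code=trust_remote_code
    )
    model.eval()

    # When device_map was disabled, move the whole model manually instead.
    if device_map is None and device != "cpu":
        model = model.to(device)
    return model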
68 changes: 53 additions & 15 deletions examples/llm_ptq/hf_ptq.py
@@ -281,6 +281,16 @@ def main(args):

model_type = get_model_type(model)

# Special handling for Nemotron VL models that aren't detected by standard model type detection
# For HF export, we want to keep vision unquantized, so we treat it as a regular language model
# and only quantize the language components
if model_type != "mllama" and is_multimodal_model(model):
print(
f"Detected multimodal model: {type(model).__name__}. "
f"For HF export, will quantize language components only, keeping vision unquantized."
)
# Keep as regular model type to use text-only calibration

device = model.device
if hasattr(model, "model"):
device = model.model.device
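The helper is_multimodal_model called in this hunk is not shown in the diff. A plausible structural check, consistent with the attribute probing used later in unified_export_hf.py, could look like the sketch below; this is an assumption for illustration, not the repository's actual implementation.

from torch import nn

def is_multimodal_model(model: nn.Module) -> bool:
    # Heuristic: VL wrappers typically expose a vision tower and/or a nested
    # language_model, or carry a vision_config on their config object.
    config = getattr(model, "config", None)
    return (
        hasattr(model, "vision_model")
        or hasattr(model, "language_model")
        or (config is not None and hasattr(config, "vision_config"))
    )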
@@ -487,20 +497,41 @@ def main(args):
"Please set the default input_mode to InputMode.LANGUAGE before quantizing."
)

# For Nemotron VL models, disable quantization of vision components
is_nemotron_vl = (
"nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
)
if is_nemotron_vl:
print("Disabling quantization for vision components in Nemotron VL model")
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
# Also disable radio model components specifically
quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}

if not model_is_already_quantized or calibration_only:
# Only run single sample for preview
input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
try:
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
except Exception as e:
print(
"Error during model generation. Please check if your transformers version is "
"compatible with the model."
)
print(f"Error details: {e}")
raise

# Skip preview generation for Nemotron VL models that require special handling
is_nemotron_vl = (
"nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
)
if is_nemotron_vl:
print("Skipping preview generation for Nemotron VL model (requires image input)")
generated_ids_before_ptq = None
else:
try:
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
except Exception as e:
print(
"Error during model generation. Please check if your transformers version is "
"compatible with the model."
)
print(f"Error details: {e}")
raise
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

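The quant_cfg wildcard entries added in this hunk rely on modelopt's pattern-based configuration: each key in quant_cfg["quant_cfg"] is a glob-style module-name pattern, so mapping "*vision*", "*image*", "*radio*", and "*visual*" to {"enable": False} leaves every matching submodule unquantized while the language stack is calibrated as usual. A minimal sketch of how such a config might be assembled and applied, assuming nvidia-modelopt's mtq.quantize API and a text-only calibration dataloader supplied by the caller:

import copy

import modelopt.torch.quantization as mtq

def quantize_language_only(model, calib_dataloader):
    # Start from a stock config and carve the vision components out by name pattern.
    quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
        quant_cfg["quant_cfg"][pattern] = {"enable": False}

    def forward_loop(m):
        # Text-only calibration: the vision encoder never sees image inputs here.
        for batch in calib_dataloader:
            m(batch["input_ids"])

    return mtq.quantize(model, quant_cfg, forward_loop=forward_loop)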
@@ -512,9 +543,13 @@ def main(args):
# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4":
if model_type != "llama4" and not is_nemotron_vl:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl:
print(
"Skipping post-quantization generation for Nemotron VL model (requires image input)"
)
else:
warnings.warn(
"Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -577,9 +612,12 @@ def output_decode(generated_ids, input_shape):
# Save original model config and the processor config to the export path for VLMs.
print(f"Saving original model config to {export_path}")

AutoConfig.from_pretrained(
args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
).save_pretrained(export_path)
config_kwargs = {"trust_remote_code": args.trust_remote_code}
if args.attn_implementation is not None:
config_kwargs["attn_implementation"] = args.attn_implementation
AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
export_path
)

# Try to save processor config if available
try:
@@ -758,10 +796,10 @@ def output_decode(generated_ids, input_shape):
parser.add_argument(
"--attn_implementation",
help=(
"Specify the attention implementation to use."
"Specify the attention implementation to use. "
"This arg will be passed to the HF model loading if specified."
),
default=None,
default="eager",
type=str,
)

40 changes: 40 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -131,6 +131,14 @@ def _output_hook(module, input, output):
with torch.no_grad():
fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device)
decoder_fake_input = fake_input

# Check if this is a VL model that needs special input handling
is_vl_model = (
hasattr(model.config, "vision_config")
or hasattr(model, "vision_model")
or "nemotron" in getattr(model, "name_or_path", "").lower()
)

if model_type.startswith("whisper"):
# For Whisper models, we need to pass a fake input with the specific sequence length
from transformers import AutoFeatureExtractor
@@ -139,13 +147,45 @@ def _output_hook(module, input, output):
fake_input = torch.ones(
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
).to(model.device)
elif is_vl_model:
# For VL models, run optimization on language model component only
print("Detected VL model during export - optimizing language model component")

# Run forward pass so that all modules sharing the same input are collected using forward hook.

with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
if getattr(model.config, "is_encoder_decoder", False):
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
model(fake_input, decoder_input_ids=decoder_fake_input)
elif is_vl_model:
# For VL models, try to run optimization on just the language model part
language_model = None
if hasattr(model, "language_model"):
language_model = model.language_model
print(
"Found language_model attribute - running optimization on language model only"
)
elif hasattr(model, "model") and hasattr(model.model, "language_model"):
language_model = model.model.language_model
print(
"Found language_model in model.model - running optimization on language model only"
)

if language_model is not None:
# Run optimization on just the language model with the same input format as regular LLMs
# Use the same fake_input tensor that regular LLMs use
print(
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
)
try:
language_model(fake_input)
print("✅ Language model optimization completed successfully")
except Exception as e:
print(f"Language model optimization failed: {e}")
print("Continuing with export...")
else:
print("Warning: No language_model found in VL model - skipping optimization")
print("This is unexpected for most VL models")
else:
model(fake_input)

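The dummy forward pass above runs with every quantizer disabled purely so that the module-level forward hooks installed earlier (such as _output_hook) can observe which submodules consume the same input. A stripped-down illustration of that pattern is sketched below; the hook bookkeeping is simplified and does not reproduce the real _output_hook logic.

import torch
from torch import nn

def collect_shared_inputs(model: nn.Module, fake_input: torch.Tensor):
    # Illustrative only: register hooks, run one dummy forward, then clean up.
    consumers = {}  # maps id() of an input tensor to the modules that received it
    handles = []

    def output_hook(module, inputs, output):
        if inputs:
            consumers.setdefault(id(inputs[0]), []).append(module)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(output_hook))

    with torch.no_grad():
        model(fake_input)  # quantizers are disabled around this call in the real code

    for handle in handles:
        handle.remove()
    return consumers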