Commit 54cb469

add proper detection and handling for nemotron VL model in ptq examples
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 3a43b1e

File tree: 2 files changed, +68 −13 lines

  examples/llm_ptq/example_utils.py
  examples/llm_ptq/hf_ptq.py


examples/llm_ptq/example_utils.py

Lines changed: 17 additions & 0 deletions
@@ -185,6 +185,17 @@ def get_model(
     if device == "cpu":
         device_map = "cpu"
 
+    # Special handling for vision-language models that may have device mapping issues
+    # Check if this is a VL model by looking at the model path
+    is_vl_model = any(
+        vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"]
+    )
+    if is_vl_model:
+        print(
+            "Detected vision-language model. Disabling automatic device mapping to avoid device_map errors."
+        )
+        device_map = None
+
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
@@ -282,6 +293,12 @@ def get_model(
         **model_kwargs,
     )
     model.eval()
+
+    # If device_map was disabled (None), manually move model to target device
+    if device_map is None and device != "cpu":
+        print(f"Moving model to {device} device...")
+        model = model.to(device)
+
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
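Note: a minimal sketch of how the new path-based detection and the manual device move fit together when loading a checkpoint. The helper name, checkpoint path, and the use of AutoModel below are illustrative assumptions that mirror the logic added to get_model; they are not code from the example itself.

import torch
from transformers import AutoModel

def load_for_calibration(ckpt_path: str, device: str = "cuda"):
    # Mirror the keyword check added to get_model: any checkpoint path mentioning
    # "vl", "vision", or "nemotron-nano-vl" skips automatic device mapping.
    is_vl_model = any(
        kw in ckpt_path.lower() for kw in ["vl", "vision", "nemotron-nano-vl"]
    )
    device_map = None if is_vl_model else "auto"
    model = AutoModel.from_pretrained(
        ckpt_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
    )
    model.eval()
    # With device_map disabled the weights stay on CPU, so move them explicitly.
    if device_map is None and device != "cpu":
        model = model.to(device)
    return model

# Hypothetical local checkpoint path; any path containing "vl" takes the manual path.
# model = load_for_calibration("/ckpts/Nemotron-Nano-VL", device="cuda")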

examples/llm_ptq/hf_ptq.py

Lines changed: 51 additions & 13 deletions
@@ -275,6 +275,16 @@ def main(args):
 
     model_type = get_model_type(model)
 
+    # Special handling for Nemotron VL models that aren't detected by standard model type detection
+    # For HF export, we want to keep vision unquantized, so we treat it as a regular language model
+    # and only quantize the language components
+    if model_type != "mllama" and is_multimodal_model(model):
+        print(
+            f"Detected multimodal model: {type(model).__name__}. "
+            f"For HF export, will quantize language components only, keeping vision unquantized."
+        )
+        # Keep as regular model type to use text-only calibration
+
     device = model.device
     if hasattr(model, "model"):
         device = model.model.device
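is_multimodal_model comes from the example's utility imports and is not shown in this diff. A hypothetical stand-in that captures the intent (detect a vision tower or vision config on the loaded model) might look like the following; the attribute names are common Hugging Face conventions, not the real helper's implementation.

def is_multimodal_model(model) -> bool:
    # Heuristic: vision-language checkpoints usually expose a vision config
    # and/or a vision/visual submodule alongside the language model.
    cfg = getattr(model, "config", None)
    has_vision_config = cfg is not None and (
        getattr(cfg, "vision_config", None) is not None
        or getattr(cfg, "vision_tower", None) is not None
    )
    has_vision_module = any(
        "vision" in name.lower() or "visual" in name.lower()
        for name, _ in model.named_children()
    )
    return has_vision_config or has_vision_module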
@@ -458,20 +468,41 @@
         KV_QUANT_CFG_CHOICES,
     )
 
+    # For Nemotron VL models, disable quantization of vision components
+    is_nemotron_vl = (
+        "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+    )
+    if is_nemotron_vl:
+        print("Disabling quantization for vision components in Nemotron VL model")
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        # Also disable radio model components specifically
+        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]
-        try:
-            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-        except Exception as e:
-            print(
-                "Error during model generation. Please check if your transformers version is "
-                "compatible with the model."
-            )
-            print(f"Error details: {e}")
-            raise
+
+        # Skip preview generation for Nemotron VL models that require special handling
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Skipping preview generation for Nemotron VL model (requires image input)")
+            generated_ids_before_ptq = None
+        else:
+            try:
+                generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+            except Exception as e:
+                print(
+                    "Error during model generation. Please check if your transformers version is "
+                    "compatible with the model."
+                )
+                print(f"Error details: {e}")
+                raise
     if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
         print("Applying nvfp4 quantization (MoE only) for gpt-oss")
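The wildcard keys above follow the modelopt convention that entries in quant_cfg["quant_cfg"] are glob-style patterns matched against module and quantizer names, with {"enable": False} excluding the matches from quantization. A small illustration of layering such overrides on top of a base config; the base dict below is a placeholder with the same shape, not one of the example's QUANT_CFG_CHOICES.

import copy

# Placeholder base config shaped like the modelopt quantization configs used in hf_ptq.py.
base_cfg = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8},
        "*input_quantizer": {"num_bits": 8},
        "default": {"enable": False},
    },
}

quant_cfg = copy.deepcopy(base_cfg)
# Any module whose name matches these patterns (vision tower, image projector,
# RADIO encoder, etc.) is left unquantized, so only the language stack is calibrated.
for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
    quant_cfg["quant_cfg"][pattern] = {"enable": False}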

@@ -483,9 +514,13 @@
     # Run some samples
     torch.cuda.empty_cache()
     generated_ids_after_ptq = None
-    if model_type != "llama4":
+    if model_type != "llama4" and not is_nemotron_vl:
         # Our fake quantizer may not be fully compatible with torch.compile.
         generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+    elif is_nemotron_vl:
+        print(
+            "Skipping post-quantization generation for Nemotron VL model (requires image input)"
+        )
     else:
         warnings.warn(
             "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -548,9 +583,12 @@ def output_decode(generated_ids, input_shape):
         # Save original model config and the processor config to the export path for VLMs.
         print(f"Saving original model config to {export_path}")
 
-        AutoConfig.from_pretrained(
-            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-        ).save_pretrained(export_path)
+        config_kwargs = {"trust_remote_code": args.trust_remote_code}
+        if args.attn_implementation is not None:
+            config_kwargs["attn_implementation"] = args.attn_implementation
+        AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
+            export_path
+        )
 
         # Try to save processor config if available
         try:
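The config_kwargs change keeps the saved config consistent with how the model was loaded: when the run overrides the attention implementation, the config written next to the export records the same setting. A minimal sketch of that round trip; the paths and the "eager" value are placeholders.

from transformers import AutoConfig

ckpt_path = "/ckpts/Nemotron-Nano-VL"       # placeholder source checkpoint
export_path = "/exports/nemotron-vl-quant"  # placeholder export directory
attn_implementation = "eager"               # placeholder override, e.g. for custom-code models

config_kwargs = {"trust_remote_code": True}
if attn_implementation is not None:
    config_kwargs["attn_implementation"] = attn_implementation

# Save the original model config next to the quantized weights so downstream
# loaders see the same attention implementation used during calibration.
AutoConfig.from_pretrained(ckpt_path, **config_kwargs).save_pretrained(export_path)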
