
Commit 32bdfa9

add proper detection and handling for nemotron VL model in ptq examples
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent: 98772b9

File tree: 2 files changed, +68 −13 lines


examples/llm_ptq/example_utils.py

Lines changed: 17 additions & 0 deletions
@@ -136,6 +136,17 @@ def get_model(
     if device == "cpu":
         device_map = "cpu"

+    # Special handling for vision-language models that may have device mapping issues
+    # Check if this is a VL model by looking at the model path
+    is_vl_model = any(
+        vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"]
+    )
+    if is_vl_model:
+        print(
+            "Detected vision-language model. Disabling automatic device mapping to avoid device_map errors."
+        )
+        device_map = None
+
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
@@ -235,6 +246,12 @@ def get_model(
         **model_kwargs,
     )
     model.eval()
+
+    # If device_map was disabled (None), manually move model to target device
+    if device_map is None and device != "cpu":
+        print(f"Moving model to {device} device...")
+        model = model.to(device)
+
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
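Outside of get_model, the same fallback can be reproduced directly with transformers: load without automatic device mapping, then move the whole model in one explicit step. A minimal sketch, assuming the checkpoint fits on a single GPU (the path is a placeholder, and the Auto class may differ for a given VL checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM

ckpt_path = "path/to/vl-checkpoint"  # placeholder
device = "cuda" if torch.cuda.is_available() else "cpu"

# device_map=None skips accelerate's automatic sharding (the behavior the diff
# disables for VL checkpoints); the model is then moved explicitly afterwards.
model = AutoModelForCausalLM.from_pretrained(
    ckpt_path,
    torch_dtype="auto",
    device_map=None,
    trust_remote_code=True,
)
model.eval()
if device != "cpu":
    model = model.to(device)
```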

examples/llm_ptq/hf_ptq.py

Lines changed: 51 additions & 13 deletions
@@ -281,6 +281,16 @@ def main(args):

     model_type = get_model_type(model)

+    # Special handling for Nemotron VL models that aren't detected by standard model type detection
+    # For HF export, we want to keep vision unquantized, so we treat it as a regular language model
+    # and only quantize the language components
+    if model_type != "mllama" and is_multimodal_model(model):
+        print(
+            f"Detected multimodal model: {type(model).__name__}. "
+            f"For HF export, will quantize language components only, keeping vision unquantized."
+        )
+        # Keep as regular model type to use text-only calibration
+
     device = model.device
     if hasattr(model, "model"):
         device = model.model.device
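is_multimodal_model comes from the example utilities and its body is not part of this diff; a plausible attribute-based heuristic is sketched below (the attribute names are assumptions, not the repo's actual implementation):

```python
def is_multimodal_model(model) -> bool:
    """Heuristic sketch: treat a model as multimodal if it exposes a vision submodule."""
    vision_attrs = ("vision_model", "vision_tower", "visual", "image_encoder", "radio_model")
    inner = getattr(model, "model", None)
    return any(hasattr(model, attr) for attr in vision_attrs) or (
        inner is not None and any(hasattr(inner, attr) for attr in vision_attrs)
    )
```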
@@ -487,20 +497,41 @@ def main(args):
            "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
        )

+    # For Nemotron VL models, disable quantization of vision components
+    is_nemotron_vl = (
+        "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+    )
+    if is_nemotron_vl:
+        print("Disabling quantization for vision components in Nemotron VL model")
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        # Also disable radio model components specifically
+        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]
-        try:
-            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-        except Exception as e:
-            print(
-                "Error during model generation. Please check if your transformers version is "
-                "compatible with the model."
-            )
-            print(f"Error details: {e}")
-            raise
+
+        # Skip preview generation for Nemotron VL models that require special handling
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Skipping preview generation for Nemotron VL model (requires image input)")
+            generated_ids_before_ptq = None
+        else:
+            try:
+                generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+            except Exception as e:
+                print(
+                    "Error during model generation. Please check if your transformers version is "
+                    "compatible with the model."
+                )
+                print(f"Error details: {e}")
+                raise
     if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
         print("Applying nvfp4 quantization (MoE only) for gpt-oss")
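The quant_cfg being edited here follows the TensorRT Model Optimizer (modelopt) convention: keys under "quant_cfg" are wildcard patterns matched against quantizer and module names. A minimal sketch of how the vision-disabling entries compose with a stock config (FP8_DEFAULT_CFG is chosen purely for illustration, and calibrate_fn stands in for a user-supplied forward loop):

```python
import copy

import modelopt.torch.quantization as mtq

# Start from a stock modelopt config; FP8 here is purely illustrative.
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)

# Wildcard keys disable quantizers whose names match vision-related submodules.
for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
    quant_cfg["quant_cfg"][pattern] = {"enable": False}

# model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_fn)
```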

@@ -512,9 +543,13 @@ def main(args):
     # Run some samples
     torch.cuda.empty_cache()
     generated_ids_after_ptq = None
-    if model_type != "llama4":
+    if model_type != "llama4" and not is_nemotron_vl:
         # Our fake quantizer may not be fully compatible with torch.compile.
         generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+    elif is_nemotron_vl:
+        print(
+            "Skipping post-quantization generation for Nemotron VL model (requires image input)"
+        )
     else:
         warnings.warn(
             "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -577,9 +612,12 @@ def output_decode(generated_ids, input_shape):
     # Save original model config and the processor config to the export path for VLMs.
     print(f"Saving original model config to {export_path}")

-    AutoConfig.from_pretrained(
-        args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-    ).save_pretrained(export_path)
+    config_kwargs = {"trust_remote_code": args.trust_remote_code}
+    if args.attn_implementation is not None:
+        config_kwargs["attn_implementation"] = args.attn_implementation
+    AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
+        export_path
+    )

     # Try to save processor config if available
     try:
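The hunk ends just before the processor-saving code; a sketch of how such an export step typically looks with transformers follows (the AutoProcessor usage is an assumption about what the truncated try block does, not code from this commit, and the paths are placeholders):

```python
from transformers import AutoConfig, AutoProcessor

ckpt_path = "path/to/checkpoint"   # placeholder
export_path = "exported_model"     # placeholder

# Mirrors the config save above: forward attn_implementation only when it is set.
config_kwargs = {"trust_remote_code": True}
# config_kwargs["attn_implementation"] = "flash_attention_2"  # assumption: only if requested
AutoConfig.from_pretrained(ckpt_path, **config_kwargs).save_pretrained(export_path)

# Processor save of the kind the truncated try block likely performs; VLM checkpoints
# ship a processor, text-only ones may not, hence the broad try/except.
try:
    AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True).save_pretrained(export_path)
except Exception:
    pass
```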
