@@ -281,6 +281,16 @@ def main(args):

     model_type = get_model_type(model)

+    # Special handling for Nemotron VL models that aren't detected by standard model type detection.
+    # For HF export, we want to keep vision unquantized, so we treat it as a regular language model
+    # and only quantize the language components.
+    if model_type != "mllama" and is_multimodal_model(model):
+        print(
+            f"Detected multimodal model: {type(model).__name__}. "
+            f"For HF export, will quantize language components only, keeping vision unquantized."
+        )
+        # Keep as regular model type to use text-only calibration
+
     device = model.device
     if hasattr(model, "model"):
         device = model.model.device
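For reference, `is_multimodal_model` is used above but not defined in this diff. A minimal sketch of what such a detector might check, assuming detection keys off vision-related config entries or submodules; all attribute names here are illustrative, not the repo's actual implementation:

```python
# Hypothetical sketch only: the attribute names checked here are assumptions,
# not the actual is_multimodal_model implementation in this repo.
def is_multimodal_model(model) -> bool:
    config = getattr(model, "config", None)
    # Many HF VLMs carry a nested vision config on the top-level config.
    has_vision_config = getattr(config, "vision_config", None) is not None
    # Others expose the vision encoder as a named submodule.
    has_vision_module = any(
        hasattr(model, name) for name in ("vision_model", "vision_tower", "visual")
    )
    return has_vision_config or has_vision_module
```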
@@ -487,20 +497,41 @@ def main(args):
                "Please set the default input_mode to InputMode.LANGUAGE before quantizing."
            )

+        # For Nemotron VL models, disable quantization of vision components.
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Disabling quantization for vision components in Nemotron VL model")
+            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+            # Also disable radio model components specifically
+            quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
        if not model_is_already_quantized or calibration_only:
            # Only run a single sample for preview
            input_ids = next(iter(calib_dataloader))[
                "input_features" if model_type == "whisper" else "input_ids"
            ][0:1]
-            try:
-                generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-            except Exception as e:
-                print(
-                    "Error during model generation. Please check if your transformers version is "
-                    "compatible with the model."
-                )
-                print(f"Error details: {e}")
-                raise
+
+            # Skip preview generation for Nemotron VL models that require special handling
+            is_nemotron_vl = (
+                "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+            )
+            if is_nemotron_vl:
+                print("Skipping preview generation for Nemotron VL model (requires image input)")
+                generated_ids_before_ptq = None
+            else:
+                try:
+                    generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+                except Exception as e:
+                    print(
+                        "Error during model generation. Please check if your transformers version is "
+                        "compatible with the model."
+                    )
+                    print(f"Error details: {e}")
+                    raise

    if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
        print("Applying nvfp4 quantization (MoE only) for gpt-oss")
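The wildcard keys added above rely on the quantization config resolving glob-style patterns against module names, so `*vision*`, `*image*`, `*radio*`, and `*visual*` disable every matching submodule. As a rough sketch of those matching semantics (using `fnmatch`; this is not modelopt's actual resolver):

```python
# Rough illustration of glob-style pattern matching over module names,
# mimicking the intent of the "*vision*"-style wildcard keys above.
from fnmatch import fnmatch

DISABLED_PATTERNS = ["*vision*", "*image*", "*radio*", "*visual*"]

def quantization_enabled(module_name: str) -> bool:
    return not any(fnmatch(module_name, pattern) for pattern in DISABLED_PATTERNS)

assert not quantization_enabled("vision_model.encoder.layers.0.mlp.fc1")
assert quantization_enabled("language_model.layers.0.self_attn.q_proj")
```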
@@ -512,9 +543,13 @@ def main(args):
        # Run some samples
        torch.cuda.empty_cache()
        generated_ids_after_ptq = None
-        if model_type != "llama4":
+        if model_type != "llama4" and not is_nemotron_vl:
            # Our fake quantizer may not be fully compatible with torch.compile.
            generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+        elif is_nemotron_vl:
+            print(
+                "Skipping post-quantization generation for Nemotron VL model (requires image input)"
+            )
        else:
            warnings.warn(
                "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -577,9 +612,12 @@ def output_decode(generated_ids, input_shape):
    # Save original model config and the processor config to the export path for VLMs.
    print(f"Saving original model config to {export_path}")

-    AutoConfig.from_pretrained(
-        args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-    ).save_pretrained(export_path)
+    config_kwargs = {"trust_remote_code": args.trust_remote_code}
+    if args.attn_implementation is not None:
+        config_kwargs["attn_implementation"] = args.attn_implementation
+    AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
+        export_path
+    )

    # Try to save processor config if available
    try:
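This forwards `attn_implementation` to `AutoConfig.from_pretrained` only when the flag was set, so the exported config matches how the model was loaded. A standalone sketch of the same conditional-kwargs pattern, with placeholder values:

```python
# Standalone sketch of the conditional-kwargs pattern above; the checkpoint
# path, flag value, and export directory are placeholders.
from transformers import AutoConfig

ckpt_path = "path/to/checkpoint"  # placeholder
attn_implementation = "sdpa"      # placeholder; may be None when unset

config_kwargs = {"trust_remote_code": True}
if attn_implementation is not None:
    # Forward the flag only when explicitly set, so the loader default is kept.
    config_kwargs["attn_implementation"] = attn_implementation

config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
config.save_pretrained("export_dir")  # placeholder export path
```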