@@ -275,6 +275,16 @@ def main(args):

     model_type = get_model_type(model)

+    # Special handling for Nemotron VL models that aren't detected by standard model type detection.
+    # For HF export, we want to keep vision unquantized, so we treat it as a regular language model
+    # and only quantize the language components.
+    if model_type != "mllama" and is_multimodal_model(model):
+        print(
+            f"Detected multimodal model: {type(model).__name__}. "
+            f"For HF export, will quantize language components only, keeping vision unquantized."
+        )
+        # Keep as regular model type to use text-only calibration.
+
     device = model.device
     if hasattr(model, "model"):
         device = model.model.device
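The is_multimodal_model helper used above is assumed to probe the loaded model for vision components; a minimal sketch of what such a check could look like (an assumption, not the repository's actual implementation):

def is_multimodal_model(model) -> bool:
    # Hypothetical sketch: treat a model as multimodal when its HF config
    # exposes a vision sub-config or an image token id.
    config = getattr(model, "config", None)
    if config is None:
        return False
    return any(
        hasattr(config, attr)
        for attr in ("vision_config", "vision_tower_config", "image_token_id")
    )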
@@ -458,20 +468,41 @@ def main(args):
         KV_QUANT_CFG_CHOICES,
     )

+    # For Nemotron VL models, disable quantization of vision components.
+    is_nemotron_vl = (
+        "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+    )
+    if is_nemotron_vl:
+        print("Disabling quantization for vision components in Nemotron VL model")
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        # Also disable radio model components specifically.
+        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]
-        try:
-            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-        except Exception as e:
-            print(
-                "Error during model generation. Please check if your transformers version is "
-                "compatible with the model."
-            )
-            print(f"Error details: {e}")
-            raise
+
+        # Skip preview generation for Nemotron VL models that require special handling.
+        is_nemotron_vl = (
+            "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+        )
+        if is_nemotron_vl:
+            print("Skipping preview generation for Nemotron VL model (requires image input)")
+            generated_ids_before_ptq = None
+        else:
+            try:
+                generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
+            except Exception as e:
+                print(
+                    "Error during model generation. Please check if your transformers version is "
+                    "compatible with the model."
+                )
+                print(f"Error details: {e}")
+                raise
     if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
         print("Applying nvfp4 quantization (MoE only) for gpt-oss")

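The "*vision*"-style keys work because quant_cfg entries are pattern-matched against module names, so every submodule whose qualified name matches one of the globs gets its quantizers disabled. A rough illustration under the assumption of fnmatch-style glob semantics (modelopt's actual matcher may differ):

from fnmatch import fnmatch

# Illustration only: emulate glob-style matching of quant_cfg keys
# against dotted module names, assuming fnmatch semantics.
disable_patterns = {"*vision*", "*image*", "*radio*", "*visual*"}
module_names = [
    "language_model.layers.0.mlp.fc1",       # should stay quantized
    "vision_model.encoder.layers.0.attention",  # matches *vision*
    "radio_model.patch_embed.proj",           # matches *radio*
]
for name in module_names:
    skipped = any(fnmatch(name, p) for p in disable_patterns)
    print(name, "->", "quantization disabled" if skipped else "quantized")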
@@ -483,9 +514,13 @@ def main(args):
         # Run some samples
         torch.cuda.empty_cache()
         generated_ids_after_ptq = None
-        if model_type != "llama4":
+        if model_type != "llama4" and not is_nemotron_vl:
             # Our fake quantizer may not be fully compatible with torch.compile.
             generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+        elif is_nemotron_vl:
+            print(
+                "Skipping post-quantization generation for Nemotron VL model (requires image input)"
+            )
         else:
             warnings.warn(
                 "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -548,9 +583,12 @@ def output_decode(generated_ids, input_shape):
         # Save original model config and the processor config to the export path for VLMs.
         print(f"Saving original model config to {export_path}")

-        AutoConfig.from_pretrained(
-            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-        ).save_pretrained(export_path)
+        config_kwargs = {"trust_remote_code": args.trust_remote_code}
+        if args.attn_implementation is not None:
+            config_kwargs["attn_implementation"] = args.attn_implementation
+        AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
+            export_path
+        )

         # Try to save processor config if available
         try:
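Forwarding attn_implementation to AutoConfig.from_pretrained is presumably needed so that checkpoints whose (possibly remote) config code selects an attention backend at load time can be exported on machines without that backend installed. A usage sketch; the checkpoint id and flag value below are placeholders, not from this PR:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "nvidia/some-nemotron-vl-checkpoint",  # placeholder checkpoint id
    trust_remote_code=True,
    attn_implementation="eager",  # e.g. avoid a hard flash-attn requirement
)
config.save_pretrained("./export_dir")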