     get_processor,
     get_tokenizer,
     is_enc_dec,
+    is_nemotron_vl,
+    run_nemotron_vl_preview,
 )
 from transformers import (
     AutoConfig,
@@ -48,7 +50,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
-from modelopt.torch.export.model_utils import is_multimodal_model
+from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
@@ -283,6 +285,9 @@ def main(args):

     full_model = model

+    # Detect if this is a Nemotron VL model using architecture-based detection
+    is_nemotron_vl_model = is_nemotron_vl(full_model)
+
     if model_type == "mllama":
         processor = get_processor(
             args.pyt_ckpt_path,
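The `is_nemotron_vl` helper imported at the top of the file is not shown in this diff; per the comment above, it uses architecture-based detection rather than a `model_type` string check. A minimal sketch of what such a check could look like, assuming detection against the Hugging Face config's `architectures` list (hypothetical name and logic, not the actual example_utils implementation):

```python
def is_nemotron_vl_sketch(model) -> bool:
    """Hypothetical architecture-based check; the real is_nemotron_vl helper may differ."""
    architectures = getattr(getattr(model, "config", None), "architectures", None) or []
    # Assumption: Nemotron VL checkpoints list an architecture name containing
    # both "nemotron" and "vl" (e.g. a NemotronVL-style class name).
    return any("nemotron" in arch.lower() and "vl" in arch.lower() for arch in architectures)
```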
@@ -312,15 +317,8 @@ def main(args):
         tokenizer.padding_side = "left"

     # We only quantize the language model for VLMs other than the type supported above.
-    if hasattr(model, "language_model"):
-        parent_model = model  # llama4 case
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            assert hasattr(model, "model") and hasattr(model.model, "language_model"), (
-                "Expected language_model in model.model, but attribute not found. "
-                "This may indicate an unsupported model structure."
-            )
-            parent_model = model.model  # gemma3, qwen2.5 VL case
-
+    language_model, parent_model = get_language_model_from_vl(model)
+    if language_model is not None:
         disabled_quant_cfg = {
             "quant_cfg": {"default": {"enable": False}},
             "algorithm": "max",
@@ -331,7 +329,7 @@ def main(args):
             if name != "language_model":
                 mtq.quantize(child, disabled_quant_cfg, forward_loop=None)

-        model = model.language_model
+        model = language_model
         model_type = get_model_type(model)

     if model_type == "phi4mm":
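The hasattr/property branching removed in the hunk above now lives in `get_language_model_from_vl` from `modelopt.torch.export.model_utils`. Based only on the deleted logic, a sketch of the `(language_model, parent_model)` contract that the call sites rely on might look like this (the real utility may cover additional model layouts):

```python
def get_language_model_from_vl_sketch(model):
    """Sketch of the (language_model, parent_model) lookup removed above."""
    if not hasattr(model, "language_model"):
        return None, None
    parent_model = model  # llama4 case: language_model hangs off the top-level model
    if isinstance(type(model).__dict__.get("language_model"), property):
        # gemma3 / qwen2.5 VL case: the property forwards to model.model.language_model
        if not (hasattr(model, "model") and hasattr(model.model, "language_model")):
            return None, None
        parent_model = model.model
    return parent_model.language_model, parent_model
```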
@@ -458,34 +456,65 @@ def main(args):
             KV_QUANT_CFG_CHOICES,
         )

+    # For Nemotron VL models, disable quantization of vision components
+    if is_nemotron_vl_model:
+        print("Disabling quantization for vision components in Nemotron VL model")
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        # Also disable radio model components specifically
+        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
         input_ids = next(iter(calib_dataloader))[
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]
-        try:
-            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
-        except Exception as e:
-            print(
-                "Error during model generation. Please check if your transformers version is "
-                "compatible with the model."
+
+        # Generate preview before quantization
+        if is_nemotron_vl_model and tokenizer is not None:
+            generated_ids_before_ptq = run_nemotron_vl_preview(
+                full_model,
+                tokenizer,
+                input_ids,
+                args.pyt_ckpt_path,
+                "before quantization",
+                allow_fallback=True,
             )
-            print(f"Error details: {e}")
-            raise
+        else:
+            # Standard generation for non-Nemotron VL models
+            generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
         if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
             print("Applying nvfp4 quantization (MoE only) for gpt-oss")

         # quantize the model
         model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
+
+        # For VL models, update full_model to use the quantized language model
+        if is_nemotron_vl_model:
+            _, parent_model = get_language_model_from_vl(full_model)
+            if parent_model is not None:
+                print("Updating full_model with quantized language_model...")
+                parent_model.language_model = model
+
         if args.verbose:
             mtq.print_quant_summary(model)

         # Run some samples
         torch.cuda.empty_cache()
         generated_ids_after_ptq = None
-        if model_type != "llama4":
+        if model_type != "llama4" and not is_nemotron_vl_model:
            # Our fake quantizer may not be fully compatible with torch.compile.
            generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
+        elif is_nemotron_vl_model and tokenizer is not None:
+            generated_ids_after_ptq = run_nemotron_vl_preview(
+                full_model,
+                tokenizer,
+                input_ids,
+                args.pyt_ckpt_path,
+                "after quantization",
+                allow_fallback=False,
+            )
         else:
             warnings.warn(
                 "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -518,15 +547,25 @@ def output_decode(generated_ids, input_shape):

         if generated_ids_after_ptq is not None:
             print("--------")
-            print(f"example test input: {input_decode(input_ids)}")
-            print("--------")
-            print(
-                f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
-            )
-            print("--------")
-            print(
-                f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
-            )
+            if is_nemotron_vl_model:
+                # For Nemotron VL models, generated_ids are text strings from model.chat()
+                print("Nemotron VL model text-only generation results:")
+                print(f"Text response before quantization: {generated_ids_before_ptq}")
+                print("--------")
+                print(f"Text response after quantization: {generated_ids_after_ptq}")
+                print("--------")
+                print("Note: Additional VL tests with images were run separately above")
+            else:
+                # For regular LLMs, generated_ids are token tensors that need decoding
+                print(f"example test input: {input_decode(input_ids)}")
+                print("--------")
+                print(
+                    f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
+                )
+                print("--------")
+                print(
+                    f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
+                )
     else:
         warnings.warn("Skipping quantization: model is already quantized.")

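`run_nemotron_vl_preview` comes from the example utilities and its body is not part of this diff; the print statements above indicate it returns a text response via the model's chat() interface rather than token ids. A rough, hypothetical sketch of such a text-only preview, assuming an InternVL-style `model.chat(tokenizer, pixel_values, question, generation_config)` signature (the real helper also handles image prompts and the allow_fallback path):

```python
def run_nemotron_vl_preview_sketch(full_model, tokenizer, input_ids, stage):
    """Hypothetical text-only preview; the real run_nemotron_vl_preview may differ."""
    # Recover a plain-text prompt from the calibration sample.
    question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    # Assumption: the checkpoint exposes an InternVL-style chat() API; pixel_values=None
    # keeps this a text-only round trip.
    response = full_model.chat(tokenizer, None, question, dict(max_new_tokens=100))
    print(f"Nemotron VL text response ({stage}): {response}")
    return response
```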
@@ -548,9 +587,12 @@ def output_decode(generated_ids, input_shape):

         # Save original model config and the processor config to the export path for VLMs.
         print(f"Saving original model config to {export_path}")
-        AutoConfig.from_pretrained(
-            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-        ).save_pretrained(export_path)
+        config_kwargs = {"trust_remote_code": args.trust_remote_code}
+        if args.attn_implementation is not None:
+            config_kwargs["attn_implementation"] = args.attn_implementation
+        AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
+            export_path
+        )

         # Try to save processor config if available
         try:
@@ -748,7 +790,7 @@ def output_decode(generated_ids, input_shape):
     parser.add_argument(
         "--attn_implementation",
         help=(
-            "Specify the attention implementation to use."
+            "Specify the attention implementation to use. "
             "This arg will be passed to the HF model loading if specified."
         ),
         default=None,