diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 3ac167db..1493c0aa 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -136,6 +136,17 @@ def get_model( if device == "cpu": device_map = "cpu" + # Special handling for vision-language models that may have device mapping issues + # Check if this is a VL model by looking at the model path + is_vl_model = any( + vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"] + ) + if is_vl_model: + print( + "Detected vision-language model. Disabling automatic device mapping to avoid device_map errors." + ) + device_map = None + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} if attn_implementation is not None: config_kwargs["attn_implementation"] = attn_implementation @@ -235,6 +246,12 @@ def get_model( **model_kwargs, ) model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + if device == "cuda" and not is_model_on_gpu(model): print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 81f4b639..5c5f2ac7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -281,6 +281,16 @@ def main(args): model_type = get_model_type(model) + # Special handling for Nemotron VL models that aren't detected by standard model type detection + # For HF export, we want to keep vision unquantized, so we treat it as a regular language model + # and only quantize the language components + if model_type != "mllama" and is_multimodal_model(model): + print( + f"Detected multimodal model: {type(model).__name__}. " + f"For HF export, will quantize language components only, keeping vision unquantized." + ) + # Keep as regular model type to use text-only calibration + device = model.device if hasattr(model, "model"): device = model.model.device @@ -487,20 +497,41 @@ def main(args): "Please set the default input_mode to InputMode.LANGUAGE before quantizing." ) + # For Nemotron VL models, disable quantization of vision components + is_nemotron_vl = ( + "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower() + ) + if is_nemotron_vl: + print("Disabling quantization for vision components in Nemotron VL model") + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + # Also disable radio model components specifically + quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + if not model_is_already_quantized or calibration_only: # Only run single sample for preview input_ids = next(iter(calib_dataloader))[ "input_features" if model_type == "whisper" else "input_ids" ][0:1] - try: - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - except Exception as e: - print( - "Error during model generation. Please check if your transformers version is " - "compatible with the model." - ) - print(f"Error details: {e}") - raise + + # Skip preview generation for Nemotron VL models that require special handling + is_nemotron_vl = ( + "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower() + ) + if is_nemotron_vl: + print("Skipping preview generation for Nemotron VL model (requires image input)") + generated_ids_before_ptq = None + else: + try: + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + except Exception as e: + print( + "Error during model generation. Please check if your transformers version is " + "compatible with the model." + ) + print(f"Error details: {e}") + raise if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") @@ -512,9 +543,13 @@ def main(args): # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4": + if model_type != "llama4" and not is_nemotron_vl: # Our fake quantizer may not be fully compatible with torch.compile. generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) + elif is_nemotron_vl: + print( + "Skipping post-quantization generation for Nemotron VL model (requires image input)" + ) else: warnings.warn( "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." @@ -577,9 +612,12 @@ def output_decode(generated_ids, input_shape): # Save original model config and the processor config to the export path for VLMs. print(f"Saving original model config to {export_path}") - AutoConfig.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) + config_kwargs = {"trust_remote_code": args.trust_remote_code} + if args.attn_implementation is not None: + config_kwargs["attn_implementation"] = args.attn_implementation + AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained( + export_path + ) # Try to save processor config if available try: @@ -758,10 +796,10 @@ def output_decode(generated_ids, input_shape): parser.add_argument( "--attn_implementation", help=( - "Specify the attention implementation to use." + "Specify the attention implementation to use. " "This arg will be passed to the HF model loading if specified." ), - default=None, + default="eager", type=str, ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f514e660..5be3ec46 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -131,6 +131,14 @@ def _output_hook(module, input, output): with torch.no_grad(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input + + # Check if this is a VL model that needs special input handling + is_vl_model = ( + hasattr(model.config, "vision_config") + or hasattr(model, "vision_model") + or "nemotron" in getattr(model, "name_or_path", "").lower() + ) + if model_type.startswith("whisper"): # For Whisper models, we need to pass a fake input with the specific sequence length from transformers import AutoFeatureExtractor @@ -139,6 +147,9 @@ def _output_hook(module, input, output): fake_input = torch.ones( [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) + elif is_vl_model: + # For VL models, run optimization on language model component only + print("Detected VL model during export - optimizing language model component") # Run forward pass so that all modules sharing the same input are collected using forward hook. @@ -146,6 +157,35 @@ def _output_hook(module, input, output): if getattr(model.config, "is_encoder_decoder", False): # For encoder-decoder models, we need to pass both the encoder and decoder input ids model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model: + # For VL models, try to run optimization on just the language model part + language_model = None + if hasattr(model, "language_model"): + language_model = model.language_model + print( + "Found language_model attribute - running optimization on language model only" + ) + elif hasattr(model, "model") and hasattr(model.model, "language_model"): + language_model = model.model.language_model + print( + "Found language_model in model.model - running optimization on language model only" + ) + + if language_model is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + try: + language_model(fake_input) + print("✅ Language model optimization completed successfully") + except Exception as e: + print(f"Language model optimization failed: {e}") + print("Continuing with export...") + else: + print("Warning: No language_model found in VL model - skipping optimization") + print("This is unexpected for most VL models") else: model(fake_input)