From 98772b9cb6392632ce5f04c6bb45ca28f3deb31e Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Wed, 17 Sep 2025 21:48:28 +0000 Subject: [PATCH 1/5] default attn_implementaion to eager to avoid issues Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 81f4b639..77831e80 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -758,10 +758,10 @@ def output_decode(generated_ids, input_shape): parser.add_argument( "--attn_implementation", help=( - "Specify the attention implementation to use." + "Specify the attention implementation to use. " "This arg will be passed to the HF model loading if specified." ), - default=None, + default="eager", type=str, ) From 32bdfa9451fe93cfd646428e37686747be24da4e Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 19 Sep 2025 00:13:32 +0000 Subject: [PATCH 2/5] add proper detection and handling for nemotron VL model in ptq examples Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/example_utils.py | 17 ++++++++ examples/llm_ptq/hf_ptq.py | 64 ++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 3ac167db..1493c0aa 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -136,6 +136,17 @@ def get_model( if device == "cpu": device_map = "cpu" + # Special handling for vision-language models that may have device mapping issues + # Check if this is a VL model by looking at the model path + is_vl_model = any( + vl_keyword in ckpt_path.lower() for vl_keyword in ["vl", "vision", "nemotron-nano-vl"] + ) + if is_vl_model: + print( + "Detected vision-language model. Disabling automatic device mapping to avoid device_map errors." + ) + device_map = None + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} if attn_implementation is not None: config_kwargs["attn_implementation"] = attn_implementation @@ -235,6 +246,12 @@ def get_model( **model_kwargs, ) model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + if device == "cuda" and not is_model_on_gpu(model): print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 77831e80..5c5f2ac7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -281,6 +281,16 @@ def main(args): model_type = get_model_type(model) + # Special handling for Nemotron VL models that aren't detected by standard model type detection + # For HF export, we want to keep vision unquantized, so we treat it as a regular language model + # and only quantize the language components + if model_type != "mllama" and is_multimodal_model(model): + print( + f"Detected multimodal model: {type(model).__name__}. " + f"For HF export, will quantize language components only, keeping vision unquantized." + ) + # Keep as regular model type to use text-only calibration + device = model.device if hasattr(model, "model"): device = model.model.device @@ -487,20 +497,41 @@ def main(args): "Please set the default input_mode to InputMode.LANGUAGE before quantizing." 
) + # For Nemotron VL models, disable quantization of vision components + is_nemotron_vl = ( + "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower() + ) + if is_nemotron_vl: + print("Disabling quantization for vision components in Nemotron VL model") + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + # Also disable radio model components specifically + quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + if not model_is_already_quantized or calibration_only: # Only run single sample for preview input_ids = next(iter(calib_dataloader))[ "input_features" if model_type == "whisper" else "input_ids" ][0:1] - try: - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - except Exception as e: - print( - "Error during model generation. Please check if your transformers version is " - "compatible with the model." - ) - print(f"Error details: {e}") - raise + + # Skip preview generation for Nemotron VL models that require special handling + is_nemotron_vl = ( + "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower() + ) + if is_nemotron_vl: + print("Skipping preview generation for Nemotron VL model (requires image input)") + generated_ids_before_ptq = None + else: + try: + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + except Exception as e: + print( + "Error during model generation. Please check if your transformers version is " + "compatible with the model." + ) + print(f"Error details: {e}") + raise if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") @@ -512,9 +543,13 @@ def main(args): # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4": + if model_type != "llama4" and not is_nemotron_vl: # Our fake quantizer may not be fully compatible with torch.compile. generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) + elif is_nemotron_vl: + print( + "Skipping post-quantization generation for Nemotron VL model (requires image input)" + ) else: warnings.warn( "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." @@ -577,9 +612,12 @@ def output_decode(generated_ids, input_shape): # Save original model config and the processor config to the export path for VLMs. 
print(f"Saving original model config to {export_path}") - AutoConfig.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) + config_kwargs = {"trust_remote_code": args.trust_remote_code} + if args.attn_implementation is not None: + config_kwargs["attn_implementation"] = args.attn_implementation + AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained( + export_path + ) # Try to save processor config if available try: From c71b66172ab756641add131666e6b49fc783b402 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 19 Sep 2025 00:18:01 +0000 Subject: [PATCH 3/5] create fake vl inputs in export for nemotron VL model Signed-off-by: Zhiyu Cheng --- modelopt/torch/export/unified_export_hf.py | 75 ++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f514e660..c5a43834 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -73,6 +73,58 @@ SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] +def _create_fake_vl_inputs(model, fake_input_ids): + """Create fake vision-language model inputs for export process. + + Args: + model: The VL model + fake_input_ids: The fake text input IDs tensor + + Returns: + dict: Dictionary of fake inputs for the VL model + """ + device = fake_input_ids.device + batch_size = fake_input_ids.shape[0] + + # Create fake inputs based on common VL model patterns + fake_inputs = { + "input_ids": fake_input_ids, + "attention_mask": torch.ones_like(fake_input_ids), + } + + # Add vision-specific inputs based on model configuration + if hasattr(model.config, "vision_config"): + vision_config = model.config.vision_config + # Create fake pixel values based on vision config + if hasattr(vision_config, "image_size"): + image_size = vision_config.image_size + else: + image_size = 224 # Default size + + if hasattr(vision_config, "num_channels"): + num_channels = vision_config.num_channels + else: + num_channels = 3 # RGB default + + # Create fake pixel values + fake_inputs["pixel_values"] = torch.zeros( + [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device + ) + + # Handle Nemotron-specific inputs + model_name = getattr(model, "name_or_path", "").lower() + if "nemotron" in model_name: + # Nemotron models may need specific image flags + fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device) + + # Some VL models need aspect ratio information + fake_inputs["aspect_ratio_ids"] = None + fake_inputs["aspect_ratio_mask"] = None + fake_inputs["cross_attention_mask"] = None + + return fake_inputs + + def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -131,6 +183,14 @@ def _output_hook(module, input, output): with torch.no_grad(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input + + # Check if this is a VL model that needs special input handling + is_vl_model = ( + hasattr(model.config, "vision_config") + or hasattr(model, "vision_model") + or "nemotron" in getattr(model, "name_or_path", "").lower() + ) + if model_type.startswith("whisper"): # For Whisper models, we need to pass a fake input with the specific sequence length from transformers import AutoFeatureExtractor @@ -139,6 +199,18 @@ def _output_hook(module, input, output): fake_input = 
torch.ones( [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) + elif is_vl_model: + # For VL models, create proper fake vision inputs + print("Detected VL model during export - creating fake vision inputs") + try: + # Try to create proper fake vision inputs for the VL model + fake_kwargs = _create_fake_vl_inputs(model, fake_input) + except Exception as e: + print(f"Failed to create fake VL inputs: {e}") + print("Skipping requantize_resmooth_fused_llm_layers for VL model") + for handle in handles: + handle.remove() + return # Run forward pass so that all modules sharing the same input are collected using forward hook. @@ -146,6 +218,9 @@ def _output_hook(module, input, output): if getattr(model.config, "is_encoder_decoder", False): # For encoder-decoder models, we need to pass both the encoder and decoder input ids model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model: + # For VL models, use the fake vision inputs + model(**fake_kwargs) else: model(fake_input) From f40501d7de11ebcd6ceb26f1fd90ebd2465028ff Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 19 Sep 2025 22:38:30 +0000 Subject: [PATCH 4/5] update fake inputs generation, initialize distributed for Nemotron models Signed-off-by: Zhiyu Cheng --- modelopt/torch/export/unified_export_hf.py | 121 ++++++++++++++++++--- 1 file changed, 106 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c5a43834..2a21451f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -83,17 +83,26 @@ def _create_fake_vl_inputs(model, fake_input_ids): Returns: dict: Dictionary of fake inputs for the VL model """ + import inspect + device = fake_input_ids.device batch_size = fake_input_ids.shape[0] + # Get the model's forward method signature to see what parameters it accepts + forward_signature = inspect.signature(model.forward) + accepted_params = set(forward_signature.parameters.keys()) + # Create fake inputs based on common VL model patterns - fake_inputs = { - "input_ids": fake_input_ids, - "attention_mask": torch.ones_like(fake_input_ids), - } + fake_inputs = {} + + # Always include basic text inputs if accepted + if "input_ids" in accepted_params: + fake_inputs["input_ids"] = fake_input_ids + if "attention_mask" in accepted_params: + fake_inputs["attention_mask"] = torch.ones_like(fake_input_ids) - # Add vision-specific inputs based on model configuration - if hasattr(model.config, "vision_config"): + # Add vision-specific inputs based on model configuration and accepted parameters + if hasattr(model.config, "vision_config") and "pixel_values" in accepted_params: vision_config = model.config.vision_config # Create fake pixel values based on vision config if hasattr(vision_config, "image_size"): @@ -111,16 +120,34 @@ def _create_fake_vl_inputs(model, fake_input_ids): [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device ) - # Handle Nemotron-specific inputs + # Handle Nemotron-specific inputs based on testing results model_name = getattr(model, "name_or_path", "").lower() if "nemotron" in model_name: - # Nemotron models may need specific image flags - fake_inputs["image_flags"] = torch.zeros([batch_size, 1], dtype=torch.long, device=device) + if "pixel_values" in accepted_params: + # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512] + # This represents 14 image patches, each 512x512 
pixels with 3 channels + num_patches = 14 + patch_size = 512 + num_channels = 3 + + # Override any previous pixel_values with the correct Nemotron format + # Use small random values instead of zeros to avoid NoneType issues + fake_inputs["pixel_values"] = ( + torch.randn( + [num_patches, num_channels, patch_size, patch_size], + dtype=torch.float32, + device=device, + ) + * 0.1 + ) # Small values to avoid extreme activations - # Some VL models need aspect ratio information - fake_inputs["aspect_ratio_ids"] = None - fake_inputs["aspect_ratio_mask"] = None - fake_inputs["cross_attention_mask"] = None + if "image_flags" in accepted_params: + # Based on testing, image_flags should have shape [14] (no batch dimension) + # to match the [14, 256, 4096] tensor it's used to mask + num_patches = 14 # From pixel_values shape [14, 3, 512, 512] + fake_inputs["image_flags"] = torch.zeros( + [num_patches], dtype=torch.long, device=device + ) # Shape [14] to match vision tensor dimensions return fake_inputs @@ -202,6 +229,31 @@ def _output_hook(module, input, output): elif is_vl_model: # For VL models, create proper fake vision inputs print("Detected VL model during export - creating fake vision inputs") + + # Pre-emptively initialize distributed for Nemotron models that require it + model_name = getattr(model, "name_or_path", "").lower() + if "nemotron" in model_name: + import os + + import torch.distributed as dist + + if not dist.is_available() or not dist.is_initialized(): + print("Pre-initializing distributed processing for Nemotron VL model") + # Set up minimal distributed environment + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + + if dist.is_available() and not dist.is_initialized(): + try: + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + rank=0, + world_size=1, + ) + except Exception as dist_e: + print(f"Failed to initialize distributed processing: {dist_e}") try: # Try to create proper fake vision inputs for the VL model fake_kwargs = _create_fake_vl_inputs(model, fake_input) @@ -219,8 +271,47 @@ def _output_hook(module, input, output): # For encoder-decoder models, we need to pass both the encoder and decoder input ids model(fake_input, decoder_input_ids=decoder_fake_input) elif is_vl_model: - # For VL models, use the fake vision inputs - model(**fake_kwargs) + # For VL models, try to run optimization on just the language model part + language_model = None + if hasattr(model, "language_model"): + language_model = model.language_model + print( + "Found language_model attribute - running optimization on language model only" + ) + elif hasattr(model, "model") and hasattr(model.model, "language_model"): + language_model = model.model.language_model + print( + "Found language_model in model.model - running optimization on language model only" + ) + + if language_model is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + try: + language_model(fake_input) + print("✅ Language model optimization completed successfully") + except Exception as e: + print(f"Language model optimization failed: {e}") + print("Continuing with export...") + else: + # Fallback: try full model with VL inputs + print("No separate language_model found - 
trying full VL model") + try: + model(**fake_kwargs) + print("✅ Full VL model optimization completed successfully") + except (ValueError, RuntimeError, AttributeError) as e: + if ( + "Default process group has not been initialized" in str(e) + or "must match the size of tensor" in str(e) + or "'bool' object has no attribute 'sum'" in str(e) + ): + print(f"VL model forward pass failed: {e}") + print("Skipping optimization for VL model - continuing with export") + else: + raise else: model(fake_input) From 281029eb21b480e98b0b467c3966883de032e020 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Sat, 20 Sep 2025 00:26:13 +0000 Subject: [PATCH 5/5] remove distributed prcessing setup and vision input generation since we process language model part only in export Signed-off-by: Zhiyu Cheng --- modelopt/torch/export/unified_export_hf.py | 134 +-------------------- 1 file changed, 4 insertions(+), 130 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 2a21451f..5be3ec46 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -73,85 +73,6 @@ SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] -def _create_fake_vl_inputs(model, fake_input_ids): - """Create fake vision-language model inputs for export process. - - Args: - model: The VL model - fake_input_ids: The fake text input IDs tensor - - Returns: - dict: Dictionary of fake inputs for the VL model - """ - import inspect - - device = fake_input_ids.device - batch_size = fake_input_ids.shape[0] - - # Get the model's forward method signature to see what parameters it accepts - forward_signature = inspect.signature(model.forward) - accepted_params = set(forward_signature.parameters.keys()) - - # Create fake inputs based on common VL model patterns - fake_inputs = {} - - # Always include basic text inputs if accepted - if "input_ids" in accepted_params: - fake_inputs["input_ids"] = fake_input_ids - if "attention_mask" in accepted_params: - fake_inputs["attention_mask"] = torch.ones_like(fake_input_ids) - - # Add vision-specific inputs based on model configuration and accepted parameters - if hasattr(model.config, "vision_config") and "pixel_values" in accepted_params: - vision_config = model.config.vision_config - # Create fake pixel values based on vision config - if hasattr(vision_config, "image_size"): - image_size = vision_config.image_size - else: - image_size = 224 # Default size - - if hasattr(vision_config, "num_channels"): - num_channels = vision_config.num_channels - else: - num_channels = 3 # RGB default - - # Create fake pixel values - fake_inputs["pixel_values"] = torch.zeros( - [batch_size, num_channels, image_size, image_size], dtype=torch.float32, device=device - ) - - # Handle Nemotron-specific inputs based on testing results - model_name = getattr(model, "name_or_path", "").lower() - if "nemotron" in model_name: - if "pixel_values" in accepted_params: - # Based on testing, Nemotron expects pixel_values with shape [14, 3, 512, 512] - # This represents 14 image patches, each 512x512 pixels with 3 channels - num_patches = 14 - patch_size = 512 - num_channels = 3 - - # Override any previous pixel_values with the correct Nemotron format - # Use small random values instead of zeros to avoid NoneType issues - fake_inputs["pixel_values"] = ( - torch.randn( - [num_patches, num_channels, patch_size, patch_size], - dtype=torch.float32, - device=device, - ) - * 0.1 - ) # Small values to avoid 
extreme activations - - if "image_flags" in accepted_params: - # Based on testing, image_flags should have shape [14] (no batch dimension) - # to match the [14, 256, 4096] tensor it's used to mask - num_patches = 14 # From pixel_values shape [14, 3, 512, 512] - fake_inputs["image_flags"] = torch.zeros( - [num_patches], dtype=torch.long, device=device - ) # Shape [14] to match vision tensor dimensions - - return fake_inputs - - def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -227,42 +148,8 @@ def _output_hook(module, input, output): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) elif is_vl_model: - # For VL models, create proper fake vision inputs - print("Detected VL model during export - creating fake vision inputs") - - # Pre-emptively initialize distributed for Nemotron models that require it - model_name = getattr(model, "name_or_path", "").lower() - if "nemotron" in model_name: - import os - - import torch.distributed as dist - - if not dist.is_available() or not dist.is_initialized(): - print("Pre-initializing distributed processing for Nemotron VL model") - # Set up minimal distributed environment - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", "29500") - os.environ.setdefault("RANK", "0") - os.environ.setdefault("WORLD_SIZE", "1") - - if dist.is_available() and not dist.is_initialized(): - try: - dist.init_process_group( - backend="nccl" if torch.cuda.is_available() else "gloo", - rank=0, - world_size=1, - ) - except Exception as dist_e: - print(f"Failed to initialize distributed processing: {dist_e}") - try: - # Try to create proper fake vision inputs for the VL model - fake_kwargs = _create_fake_vl_inputs(model, fake_input) - except Exception as e: - print(f"Failed to create fake VL inputs: {e}") - print("Skipping requantize_resmooth_fused_llm_layers for VL model") - for handle in handles: - handle.remove() - return + # For VL models, run optimization on language model component only + print("Detected VL model during export - optimizing language model component") # Run forward pass so that all modules sharing the same input are collected using forward hook. @@ -297,21 +184,8 @@ def _output_hook(module, input, output): print(f"Language model optimization failed: {e}") print("Continuing with export...") else: - # Fallback: try full model with VL inputs - print("No separate language_model found - trying full VL model") - try: - model(**fake_kwargs) - print("✅ Full VL model optimization completed successfully") - except (ValueError, RuntimeError, AttributeError) as e: - if ( - "Default process group has not been initialized" in str(e) - or "must match the size of tensor" in str(e) - or "'bool' object has no attribute 'sum'" in str(e) - ): - print(f"VL model forward pass failed: {e}") - print("Skipping optimization for VL model - continuing with export") - else: - raise + print("Warning: No language_model found in VL model - skipping optimization") + print("This is unexpected for most VL models") else: model(fake_input)
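
Reviewer note (not part of the patch series): the vision-disable logic added in PATCH 2/5 boils down to a checkpoint-path heuristic plus wildcard entries in the quantization config. Below is a minimal standalone sketch of that pattern; the helper name is illustrative, and it assumes the same quant_cfg dict layout that hf_ptq.py passes to the quantizer.

    # Sketch only; mirrors the pattern from PATCH 2/5, not code taken from the patches.
    def disable_vision_quantization(quant_cfg: dict, ckpt_path: str) -> dict:
        """Turn off quantizers for vision components of a Nemotron VL checkpoint."""
        name = ckpt_path.lower()
        if "nemotron" in name and "vl" in name:
            # Wildcards match module names; {"enable": False} disables their quantizers.
            for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
                quant_cfg["quant_cfg"][pattern] = {"enable": False}
        return quant_cfg

The language components are left under whatever quantization format was requested, so only the vision/radio tower is excluded from calibration.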
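
Reviewer note (not part of the patch series): PATCH 5/5 drops the fake-image inputs and distributed setup in favor of a text-only pass through the VL wrapper's language_model submodule, reusing the same [1, 2] fake token tensor that plain LLMs get during export. A minimal sketch of that flow, with an illustrative helper name:

    # Sketch only; mirrors the language-model-only pass settled on in PATCH 4/5 and 5/5.
    import torch

    def run_text_only_forward(vl_model) -> bool:
        """Drive only the language component of a VL model with a tiny fake text input."""
        lm = getattr(vl_model, "language_model", None)
        if lm is None:
            lm = getattr(getattr(vl_model, "model", None), "language_model", None)
        if lm is None:
            return False  # export path just warns and skips optimization in this case
        fake_input = torch.ones([1, 2], dtype=torch.long, device=next(lm.parameters()).device)
        with torch.no_grad():
            lm(fake_input)  # fires the forward hooks that group layers sharing an input
        return True

Keeping the forward confined to the language model is what lets the export-time requantize/resmooth pass run without synthesizing image patches or initializing a process group.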