|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +This example demonstrates how to load a quantized Megatron-LM VLM checkpoint |
| 17 | +and perform image+text generation using the AutoBridge on multiple GPUs. |
| 18 | +
|
| 19 | +Prerequisites: |
| 20 | +First, you must run the quantization process to create a quantized checkpoint: |
| 21 | + torchrun --nproc_per_node 8 examples/quantization/quantize_vlm.py \ |
| 22 | + --hf-model-id Qwen/Qwen3-VL-8B-Instruct \ |
| 23 | + --export-quant-cfg fp8 \ |
| 24 | + --megatron-save-path ./qwen3_vl_quantized \ |
| 25 | + --tp 8 |
| 26 | +
|
| 27 | +The process is as follows: |
| 28 | +1. An AutoBridge is initialized from a pretrained Hugging Face VLM model |
| 29 | + to get the processor and model structure. |
| 30 | +2. The quantized Megatron-LM model is loaded from the checkpoint using the specified path. |
| 31 | +3. Image+text generation is performed using the loaded quantized model. |
| 32 | +
|
| 33 | +Usage: |
| 34 | +torchrun --nproc_per_node 8 examples/quantization/ptq_generate_vlm.py \ |
| 35 | + --hf-model-id Qwen/Qwen3-VL-8B-Instruct \ |
| 36 | + --megatron-load-path ./qwen3_vl_quantized \ |
| 37 | + --tp 8 \ |
| 38 | + --image-path /path/to/image.jpg \ |
| 39 | + --prompts "Describe this image." |
| 40 | +""" |
| 41 | + |
| 42 | +import argparse |
| 43 | +import os |
| 44 | +import sys |
| 45 | +import warnings |
| 46 | + |
| 47 | +import torch |
| 48 | +from megatron.core.utils import unwrap_model |
| 49 | +from quantize_utils import console |
| 50 | +from quantize_vlm import _custom_prompt_forward_loop_func |
| 51 | +from transformers import AutoProcessor |
| 52 | + |
| 53 | +from megatron.bridge import AutoBridge |
| 54 | +from megatron.bridge.models.decorators import torchrun_main |
| 55 | +from megatron.bridge.models.hf_pretrained.utils import is_safe_repo |
| 56 | + |
| 57 | + |
| 58 | +warnings.filterwarnings("ignore") |
| 59 | + |
| 60 | +HF_MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct" |
| 61 | +DEFAULT_IMAGE_PATH = "/models/demo.jpeg" |
| 62 | + |
| 63 | + |
| 64 | +def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None: |
| 65 | + """Validate that the model contains quantized layers. |
| 66 | +
|
| 67 | + This is a functional test to ensure quantized checkpoints are loaded correctly. |
| 68 | + If someone accidentally breaks the quantization loading logic (e.g., in |
| 69 | + has_modelopt_state or build_and_load_model), this check will catch it. |
| 70 | +
|
| 71 | + For VLM models, we only check for TE spec quantized layers since all supported |
| 72 | + VLM models (Qwen3-VL) use TE spec. |
| 73 | +
|
| 74 | + Args: |
| 75 | + model: The unwrapped model to validate |
| 76 | + is_rank_0: Whether this is rank 0 (for printing) |
| 77 | +
|
| 78 | + Raises: |
| 79 | + RuntimeError: If the model doesn't contain expected quantized layers |
| 80 | + """ |
| 81 | + model_str = str(model) |
| 82 | + |
| 83 | + # DEBUG: Print full model structure to diagnose CI vs local differences |
| 84 | + if is_rank_0: |
| 85 | + console.print(f"\n{'=' * 80}") |
| 86 | + console.print("[yellow]DEBUG: Full model structure:[/yellow]") |
| 87 | + console.print(f"{'=' * 80}") |
| 88 | + console.print(model_str) |
| 89 | + console.print(f"{'=' * 80}\n") |
| 90 | + |
| 91 | + # TE spec quantized layers (VLM models always use TE spec) |
| 92 | + te_spec_layers = [ |
| 93 | + "QuantTERowParallelLinear", |
| 94 | + "QuantTELayerNormColumnParallelLinear", |
| 95 | + ] |
| 96 | + |
| 97 | + # DEBUG: Check each layer individually |
| 98 | + if is_rank_0: |
| 99 | + console.print("[yellow]DEBUG: Checking for quantized layers:[/yellow]") |
| 100 | + for layer in te_spec_layers: |
| 101 | + found = layer in model_str |
| 102 | + status = "[green]FOUND[/green]" if found else "[red]NOT FOUND[/red]" |
| 103 | + console.print(f" {layer}: {status}") |
| 104 | + |
| 105 | + # Check if model has TE spec quantized layers |
| 106 | + has_te_spec = all(layer in model_str for layer in te_spec_layers) |
| 107 | + |
| 108 | + if not has_te_spec: |
| 109 | + error_msg = ( |
| 110 | + f"\n{'=' * 80}\n" |
| 111 | + f"QUANTIZATION VALIDATION FAILED!\n" |
| 112 | + f"{'=' * 80}\n" |
| 113 | + f"Expected quantized layers not found in the loaded model.\n" |
| 114 | + f"This indicates the quantized checkpoint was not loaded correctly.\n\n" |
| 115 | + f"Expected TE spec layers: {te_spec_layers}\n\n" |
| 116 | + f"This is likely due to a bug in the checkpoint loading logic.\n" |
| 117 | + f"{'=' * 80}\n" |
| 118 | + ) |
| 119 | + if is_rank_0: |
| 120 | + console.print(f"[red]{error_msg}[/red]") |
| 121 | + raise RuntimeError(error_msg) |
| 122 | + |
| 123 | + if is_rank_0: |
| 124 | + console.print( |
| 125 | + "[green]✓ Quantization validation passed: Found TE spec quantized layers " |
| 126 | + "(QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear)[/green]" |
| 127 | + ) |
| 128 | + |
| 129 | + |
| 130 | +@torchrun_main |
| 131 | +def main( |
| 132 | + hf_model_id: str = HF_MODEL_ID, |
| 133 | + tp: int = 1, |
| 134 | + pp: int = 1, |
| 135 | + ep: int = 1, |
| 136 | + etp: int = 1, |
| 137 | + megatron_load_path: str = "./quantized_megatron_checkpoint", |
| 138 | + prompts: str = "Describe this image.", |
| 139 | + osl: int = 32, |
| 140 | + image_path: str = DEFAULT_IMAGE_PATH, |
| 141 | + trust_remote_code: bool = True, |
| 142 | +) -> None: |
| 143 | + """Load a quantized Megatron-LM VLM checkpoint and perform image+text generation on multiple GPUs.""" |
| 144 | + if os.environ.get("WORLD_SIZE") is None: |
| 145 | + console.print("This script must be launched with torchrun. Please run:") |
| 146 | + console.print(f"torchrun --nproc_per_node <gpus> {sys.argv[0]}") |
| 147 | + sys.exit(1) |
| 148 | + |
| 149 | + # Check if the checkpoint path exists |
| 150 | + if not os.path.exists(megatron_load_path): |
| 151 | + console.print(f"[red]Error: Quantized checkpoint path {megatron_load_path} does not exist![/red]") |
| 152 | + console.print("[yellow]Please run the quantization process first:[/yellow]") |
| 153 | + console.print( |
| 154 | + f"[yellow]torchrun --nproc_per_node {tp} examples/quantization/quantize_vlm.py " |
| 155 | + f"--hf-model-id {hf_model_id} --megatron-save-path {megatron_load_path} --tp {tp}[/yellow]" |
| 156 | + ) |
| 157 | + sys.exit(1) |
| 158 | + |
| 159 | + # Check if the image path exists (skip check for URLs) |
| 160 | + is_url = image_path.startswith("http://") or image_path.startswith("https://") |
| 161 | + if not is_url and not os.path.exists(image_path): |
| 162 | + console.print(f"[red]Error: Image path {image_path} does not exist![/red]") |
| 163 | + sys.exit(1) |
| 164 | + |
| 165 | + # Initialize bridge from HF model to get processor and model structure |
| 166 | + bridge = AutoBridge.from_hf_pretrained( |
| 167 | + hf_model_id, |
| 168 | + trust_remote_code=is_safe_repo( |
| 169 | + trust_remote_code=trust_remote_code, |
| 170 | + hf_path=hf_model_id, |
| 171 | + ), |
| 172 | + ) |
| 173 | + |
| 174 | + # Load processor for VLM |
| 175 | + processor = AutoProcessor.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code) |
| 176 | + |
| 177 | + # Get model provider and configure for multi-GPU execution |
| 178 | + model_provider = bridge.to_megatron_provider(load_weights=False) |
| 179 | + model_provider.tensor_model_parallel_size = tp |
| 180 | + model_provider.pipeline_model_parallel_size = pp |
| 181 | + model_provider.expert_model_parallel_size = ep |
| 182 | + model_provider.expert_tensor_parallel_size = etp |
| 183 | + model_provider.pipeline_dtype = torch.bfloat16 |
| 184 | + |
| 185 | + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run |
| 186 | + model_provider.finalize() |
| 187 | + model_provider.initialize_model_parallel(seed=0) |
| 188 | + megatron_model = bridge.load_megatron_model( |
| 189 | + megatron_load_path, |
| 190 | + mp_overrides={ |
| 191 | + "tensor_model_parallel_size": tp, |
| 192 | + "pipeline_model_parallel_size": pp, |
| 193 | + "expert_model_parallel_size": ep, |
| 194 | + "expert_tensor_parallel_size": etp, |
| 195 | + }, |
| 196 | + wrap_with_ddp=False, |
| 197 | + ) |
| 198 | + megatron_model = [m.cuda() for m in megatron_model] |
| 199 | + |
| 200 | + # Now we can check for rank |
| 201 | + is_rank_0 = torch.distributed.get_rank() == 0 |
| 202 | + |
| 203 | + if is_rank_0: |
| 204 | + console.print(f"[green]Tensor parallel size: {model_provider.tensor_model_parallel_size}[/green]") |
| 205 | + console.print(f"[green]Pipeline parallel size: {model_provider.pipeline_model_parallel_size}[/green]") |
| 206 | + console.print(f"[green]Expert parallel size: {model_provider.expert_model_parallel_size}[/green]") |
| 207 | + console.print(f"[green]Expert tensor parallel size: {model_provider.expert_tensor_parallel_size}[/green]") |
| 208 | + console.print(f"[green]Loaded quantized model from: {megatron_load_path}[/green]") |
| 209 | + |
| 210 | + # Get the unwrapped model for generation |
| 211 | + unwrapped_model = unwrap_model(megatron_model)[0] |
| 212 | + unwrapped_model.eval() |
| 213 | + |
| 214 | + # Validate that the model has quantized layers |
| 215 | + _validate_quantized_model(unwrapped_model, is_rank_0) |
| 216 | + |
| 217 | + # Test quantized model with custom prompts |
| 218 | + if is_rank_0: |
| 219 | + console.print(f"[green]Loaded Quantized Model:\n {unwrapped_model}[/green]") |
| 220 | + console.print("[green]Testing quantized VLM model with image and prompt...[/green]") |
| 221 | + |
| 222 | + _custom_prompt_forward_loop_func(unwrapped_model, processor, is_rank_0, prompts, osl, test_image_path=image_path) |
| 223 | + |
| 224 | + if is_rank_0: |
| 225 | + console.print("[green]Generation completed successfully![/green]") |
| 226 | + |
| 227 | + |
| 228 | +if __name__ == "__main__": |
| 229 | + parser = argparse.ArgumentParser( |
| 230 | + description="Load a quantized Megatron-LM VLM checkpoint and perform image+text generation on multiple GPUs" |
| 231 | + ) |
| 232 | + parser.add_argument( |
| 233 | + "--hf-model-id", |
| 234 | + type=str, |
| 235 | + default=HF_MODEL_ID, |
| 236 | + help="HuggingFace model ID for processor and model structure (e.g., Qwen/Qwen3-VL-8B-Instruct)", |
| 237 | + ) |
| 238 | + |
| 239 | + parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size") |
| 240 | + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallelism size") |
| 241 | + parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size") |
| 242 | + parser.add_argument("--etp", type=int, default=1, help="Expert tensor parallelism size") |
| 243 | + parser.add_argument( |
| 244 | + "--megatron-load-path", |
| 245 | + type=str, |
| 246 | + default="./quantized_megatron_checkpoint", |
| 247 | + help="Path to the quantized Megatron checkpoint to load (must be created first using quantize_vlm.py)", |
| 248 | + ) |
| 249 | + parser.add_argument( |
| 250 | + "--prompts", |
| 251 | + type=str, |
| 252 | + default="Describe this image.", |
| 253 | + help="Text prompt for testing quantized VLM model.", |
| 254 | + ) |
| 255 | + parser.add_argument( |
| 256 | + "--osl", |
| 257 | + type=int, |
| 258 | + default=32, |
| 259 | + help="Output sequence length for generation.", |
| 260 | + ) |
| 261 | + parser.add_argument( |
| 262 | + "--image-path", |
| 263 | + type=str, |
| 264 | + default=DEFAULT_IMAGE_PATH, |
| 265 | + help="Path to the image file for VLM generation.", |
| 266 | + ) |
| 267 | + parser.add_argument("--trust-remote-code", action="store_true", default=True, help="if trust_remote_code") |
| 268 | + |
| 269 | + args = parser.parse_args() |
| 270 | + main( |
| 271 | + args.hf_model_id, |
| 272 | + args.tp, |
| 273 | + args.pp, |
| 274 | + args.ep, |
| 275 | + args.etp, |
| 276 | + args.megatron_load_path, |
| 277 | + args.prompts, |
| 278 | + args.osl, |
| 279 | + args.image_path, |
| 280 | + args.trust_remote_code, |
| 281 | + ) |
| 282 | + |
| 283 | + if torch.distributed.is_initialized(): |
| 284 | + torch.distributed.destroy_process_group() |
0 commit comments