
Commit cf8cdf0

Support TE spec, Add generation function (load and generate)

Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent 33924aa

File tree

12 files changed (+980, -86 lines)


.github/workflows/cicd-main.yml
Lines changed: 1 addition & 0 deletions

@@ -395,6 +395,7 @@ jobs:
 - script: L2_Launch_models_qwen
 - script: L2_Launch_models_qwen_quantization
 - script: L2_Launch_models_qwen_vl
+- script: L2_Launch_models_qwen_vl_quantization
 - script: L2_Launch_recipes_llama_1b
 - script: L2_Launch_recipes_llama_3b
 - script: L2_Launch_recipes_llama_distill

examples/quantization/ptq_generate.py
Lines changed: 30 additions & 11 deletions

@@ -58,8 +58,9 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
     If someone accidentally breaks the quantization loading logic (e.g., in
     has_modelopt_state or build_and_load_model), this check will catch it.

-    We check for QuantRowParallelLinear and QuantColumnParallelLinear as these
-    are present in all quantized model architectures (GPT, Llama, Qwen, Nemotron-H, etc).
+    We check for quantized layer types that indicate successful quantization:
+    - Local spec: QuantRowParallelLinear, QuantColumnParallelLinear
+    - TE spec: QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear

     Args:
         model: The unwrapped model to validate
@@ -68,25 +69,36 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
     Raises:
         RuntimeError: If the model doesn't contain expected quantized layers
     """
-    # Check for quantized layer types that are universal across all architectures
     model_str = str(model)

-    required_quant_layers = [
+    # Local spec quantized layers
+    local_spec_layers = [
         "QuantRowParallelLinear",
         "QuantColumnParallelLinear",
     ]

-    missing_layers = [layer for layer in required_quant_layers if layer not in model_str]
+    # TE spec quantized layers
+    te_spec_layers = [
+        "QuantTERowParallelLinear",
+        "QuantTELayerNormColumnParallelLinear",
+    ]
+
+    # Check if model has local spec quantized layers
+    has_local_spec = all(layer in model_str for layer in local_spec_layers)
+
+    # Check if model has TE spec quantized layers
+    has_te_spec = all(layer in model_str for layer in te_spec_layers)

-    if missing_layers:
+    if not has_local_spec and not has_te_spec:
         error_msg = (
             f"\n{'=' * 80}\n"
             f"QUANTIZATION VALIDATION FAILED!\n"
             f"{'=' * 80}\n"
             f"Expected quantized layers not found in the loaded model.\n"
             f"This indicates the quantized checkpoint was not loaded correctly.\n\n"
-            f"Missing: {missing_layers}\n"
-            f"Expected: {required_quant_layers}\n\n"
+            f"Expected one of:\n"
+            f"  - Local spec: {local_spec_layers}\n"
+            f"  - TE spec: {te_spec_layers}\n\n"
             f"This is likely due to a bug in the checkpoint loading logic.\n"
             f"{'=' * 80}\n"
         )
@@ -95,9 +107,16 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
         raise RuntimeError(error_msg)

     if is_rank_0:
-        console.print(
-            "[green]✓ Quantization validation passed: Found QuantRowParallelLinear and QuantColumnParallelLinear[/green]"
-        )
+        if has_te_spec:
+            console.print(
+                "[green]✓ Quantization validation passed: Found TE spec quantized layers "
+                "(QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear)[/green]"
+            )
+        else:
+            console.print(
+                "[green]✓ Quantization validation passed: Found local spec quantized layers "
+                "(QuantRowParallelLinear, QuantColumnParallelLinear)[/green]"
+            )


 @torchrun_main
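
The validation above hinges on one detail: `str(model)` on an `nn.Module` includes the class name of every submodule, so a substring check is enough to tell which layer spec a checkpoint was loaded with. Below is a minimal runnable sketch of that detection pattern — the `Quant*` names match the diff, but the stand-in layer classes and the `detect_spec` helper are hypothetical, for illustration only:

```python
import torch

# Hypothetical stand-ins for the real quantized TE-spec layers (illustration only;
# the real classes come from the quantization plugin, not this snippet).
class QuantTERowParallelLinear(torch.nn.Linear): ...
class QuantTELayerNormColumnParallelLinear(torch.nn.Linear): ...


def detect_spec(model: torch.nn.Module) -> str | None:
    """Mirror the diff's check: all class names of a spec must appear in str(model)."""
    model_str = str(model)  # nn.Module repr lists each submodule's class name
    local_spec = ["QuantRowParallelLinear", "QuantColumnParallelLinear"]
    te_spec = ["QuantTERowParallelLinear", "QuantTELayerNormColumnParallelLinear"]
    if all(name in model_str for name in local_spec):
        return "local"
    if all(name in model_str for name in te_spec):
        return "te"
    return None


toy = torch.nn.Sequential(
    QuantTELayerNormColumnParallelLinear(16, 32),
    QuantTERowParallelLinear(32, 16),
)
assert detect_spec(toy) == "te"
```

Accepting either spec, rather than unconditionally requiring the local-spec names, is what lets the same validation serve both local-spec and TE-spec checkpoints.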
examples/quantization/ptq_generate_vlm.py
Lines changed: 284 additions & 0 deletions

@@ -0,0 +1,284 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example demonstrates how to load a quantized Megatron-LM VLM checkpoint
+and perform image+text generation using the AutoBridge on multiple GPUs.
+
+Prerequisites:
+    First, you must run the quantization process to create a quantized checkpoint:
+    torchrun --nproc_per_node 8 examples/quantization/quantize_vlm.py \
+        --hf-model-id Qwen/Qwen3-VL-8B-Instruct \
+        --export-quant-cfg fp8 \
+        --megatron-save-path ./qwen3_vl_quantized \
+        --tp 8
+
+The process is as follows:
+1. An AutoBridge is initialized from a pretrained Hugging Face VLM model
+   to get the processor and model structure.
+2. The quantized Megatron-LM model is loaded from the checkpoint using the specified path.
+3. Image+text generation is performed using the loaded quantized model.
+
+Usage:
+    torchrun --nproc_per_node 8 examples/quantization/ptq_generate_vlm.py \
+        --hf-model-id Qwen/Qwen3-VL-8B-Instruct \
+        --megatron-load-path ./qwen3_vl_quantized \
+        --tp 8 \
+        --image-path /path/to/image.jpg \
+        --prompts "Describe this image."
+"""
+
+import argparse
+import os
+import sys
+import warnings
+
+import torch
+from megatron.core.utils import unwrap_model
+from quantize_utils import console
+from quantize_vlm import _custom_prompt_forward_loop_func
+from transformers import AutoProcessor
+
+from megatron.bridge import AutoBridge
+from megatron.bridge.models.decorators import torchrun_main
+from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
+
+
+warnings.filterwarnings("ignore")
+
+HF_MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"
+DEFAULT_IMAGE_PATH = "/models/demo.jpeg"
+
+
+def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
+    """Validate that the model contains quantized layers.
+
+    This is a functional test to ensure quantized checkpoints are loaded correctly.
+    If someone accidentally breaks the quantization loading logic (e.g., in
+    has_modelopt_state or build_and_load_model), this check will catch it.
+
+    For VLM models, we only check for TE spec quantized layers since all supported
+    VLM models (Qwen3-VL) use TE spec.
+
+    Args:
+        model: The unwrapped model to validate
+        is_rank_0: Whether this is rank 0 (for printing)
+
+    Raises:
+        RuntimeError: If the model doesn't contain expected quantized layers
+    """
+    model_str = str(model)
+
+    # DEBUG: Print full model structure to diagnose CI vs local differences
+    if is_rank_0:
+        console.print(f"\n{'=' * 80}")
+        console.print("[yellow]DEBUG: Full model structure:[/yellow]")
+        console.print(f"{'=' * 80}")
+        console.print(model_str)
+        console.print(f"{'=' * 80}\n")
+
+    # TE spec quantized layers (VLM models always use TE spec)
+    te_spec_layers = [
+        "QuantTERowParallelLinear",
+        "QuantTELayerNormColumnParallelLinear",
+    ]
+
+    # DEBUG: Check each layer individually
+    if is_rank_0:
+        console.print("[yellow]DEBUG: Checking for quantized layers:[/yellow]")
+        for layer in te_spec_layers:
+            found = layer in model_str
+            status = "[green]FOUND[/green]" if found else "[red]NOT FOUND[/red]"
+            console.print(f"  {layer}: {status}")
+
+    # Check if model has TE spec quantized layers
+    has_te_spec = all(layer in model_str for layer in te_spec_layers)
+
+    if not has_te_spec:
+        error_msg = (
+            f"\n{'=' * 80}\n"
+            f"QUANTIZATION VALIDATION FAILED!\n"
+            f"{'=' * 80}\n"
+            f"Expected quantized layers not found in the loaded model.\n"
+            f"This indicates the quantized checkpoint was not loaded correctly.\n\n"
+            f"Expected TE spec layers: {te_spec_layers}\n\n"
+            f"This is likely due to a bug in the checkpoint loading logic.\n"
+            f"{'=' * 80}\n"
+        )
+        if is_rank_0:
+            console.print(f"[red]{error_msg}[/red]")
+        raise RuntimeError(error_msg)
+
+    if is_rank_0:
+        console.print(
+            "[green]✓ Quantization validation passed: Found TE spec quantized layers "
+            "(QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear)[/green]"
+        )
+
+
+@torchrun_main
+def main(
+    hf_model_id: str = HF_MODEL_ID,
+    tp: int = 1,
+    pp: int = 1,
+    ep: int = 1,
+    etp: int = 1,
+    megatron_load_path: str = "./quantized_megatron_checkpoint",
+    prompts: str = "Describe this image.",
+    osl: int = 32,
+    image_path: str = DEFAULT_IMAGE_PATH,
+    trust_remote_code: bool = True,
+) -> None:
+    """Load a quantized Megatron-LM VLM checkpoint and perform image+text generation on multiple GPUs."""
+    if os.environ.get("WORLD_SIZE") is None:
+        console.print("This script must be launched with torchrun. Please run:")
+        console.print(f"torchrun --nproc_per_node <gpus> {sys.argv[0]}")
+        sys.exit(1)
+
+    # Check if the checkpoint path exists
+    if not os.path.exists(megatron_load_path):
+        console.print(f"[red]Error: Quantized checkpoint path {megatron_load_path} does not exist![/red]")
+        console.print("[yellow]Please run the quantization process first:[/yellow]")
+        console.print(
+            f"[yellow]torchrun --nproc_per_node {tp} examples/quantization/quantize_vlm.py "
+            f"--hf-model-id {hf_model_id} --megatron-save-path {megatron_load_path} --tp {tp}[/yellow]"
+        )
+        sys.exit(1)
+
+    # Check if the image path exists (skip check for URLs)
+    is_url = image_path.startswith("http://") or image_path.startswith("https://")
+    if not is_url and not os.path.exists(image_path):
+        console.print(f"[red]Error: Image path {image_path} does not exist![/red]")
+        sys.exit(1)
+
+    # Initialize bridge from HF model to get processor and model structure
+    bridge = AutoBridge.from_hf_pretrained(
+        hf_model_id,
+        trust_remote_code=is_safe_repo(
+            trust_remote_code=trust_remote_code,
+            hf_path=hf_model_id,
+        ),
+    )
+
+    # Load processor for VLM
+    processor = AutoProcessor.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code)
+
+    # Get model provider and configure for multi-GPU execution
+    model_provider = bridge.to_megatron_provider(load_weights=False)
+    model_provider.tensor_model_parallel_size = tp
+    model_provider.pipeline_model_parallel_size = pp
+    model_provider.expert_model_parallel_size = ep
+    model_provider.expert_tensor_parallel_size = etp
+    model_provider.pipeline_dtype = torch.bfloat16
+
+    # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run
+    model_provider.finalize()
+    model_provider.initialize_model_parallel(seed=0)
+    megatron_model = bridge.load_megatron_model(
+        megatron_load_path,
+        mp_overrides={
+            "tensor_model_parallel_size": tp,
+            "pipeline_model_parallel_size": pp,
+            "expert_model_parallel_size": ep,
+            "expert_tensor_parallel_size": etp,
+        },
+        wrap_with_ddp=False,
+    )
+    megatron_model = [m.cuda() for m in megatron_model]
+
+    # Now we can check for rank
+    is_rank_0 = torch.distributed.get_rank() == 0
+
+    if is_rank_0:
+        console.print(f"[green]Tensor parallel size: {model_provider.tensor_model_parallel_size}[/green]")
+        console.print(f"[green]Pipeline parallel size: {model_provider.pipeline_model_parallel_size}[/green]")
+        console.print(f"[green]Expert parallel size: {model_provider.expert_model_parallel_size}[/green]")
+        console.print(f"[green]Expert tensor parallel size: {model_provider.expert_tensor_parallel_size}[/green]")
+        console.print(f"[green]Loaded quantized model from: {megatron_load_path}[/green]")
+
+    # Get the unwrapped model for generation
+    unwrapped_model = unwrap_model(megatron_model)[0]
+    unwrapped_model.eval()
+
+    # Validate that the model has quantized layers
+    _validate_quantized_model(unwrapped_model, is_rank_0)
+
+    # Test quantized model with custom prompts
+    if is_rank_0:
+        console.print(f"[green]Loaded Quantized Model:\n {unwrapped_model}[/green]")
+        console.print("[green]Testing quantized VLM model with image and prompt...[/green]")
+
+    _custom_prompt_forward_loop_func(unwrapped_model, processor, is_rank_0, prompts, osl, test_image_path=image_path)
+
+    if is_rank_0:
+        console.print("[green]Generation completed successfully![/green]")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Load a quantized Megatron-LM VLM checkpoint and perform image+text generation on multiple GPUs"
+    )
+    parser.add_argument(
+        "--hf-model-id",
+        type=str,
+        default=HF_MODEL_ID,
+        help="HuggingFace model ID for processor and model structure (e.g., Qwen/Qwen3-VL-8B-Instruct)",
+    )
+
+    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallelism size")
+    parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size")
+    parser.add_argument("--etp", type=int, default=1, help="Expert tensor parallelism size")
+    parser.add_argument(
+        "--megatron-load-path",
+        type=str,
+        default="./quantized_megatron_checkpoint",
+        help="Path to the quantized Megatron checkpoint to load (must be created first using quantize_vlm.py)",
+    )
+    parser.add_argument(
+        "--prompts",
+        type=str,
+        default="Describe this image.",
+        help="Text prompt for testing quantized VLM model.",
+    )
+    parser.add_argument(
+        "--osl",
+        type=int,
+        default=32,
+        help="Output sequence length for generation.",
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default=DEFAULT_IMAGE_PATH,
+        help="Path to the image file for VLM generation.",
+    )
+    parser.add_argument("--trust-remote-code", action="store_true", default=True, help="if trust_remote_code")
+
+    args = parser.parse_args()
+    main(
+        args.hf_model_id,
+        args.tp,
+        args.pp,
+        args.ep,
+        args.etp,
+        args.megatron_load_path,
+        args.prompts,
+        args.osl,
+        args.image_path,
+        args.trust_remote_code,
+    )
+
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
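
Condensed, the load-and-generate flow the new script implements looks like the sketch below (same APIs as the file above; argparse, existence checks, and debug output omitted; the parallel size and checkpoint path are illustrative, and it must still be launched under torchrun):

```python
import torch
from transformers import AutoProcessor
from megatron.core.utils import unwrap_model
from megatron.bridge import AutoBridge

HF_ID = "Qwen/Qwen3-VL-8B-Instruct"
bridge = AutoBridge.from_hf_pretrained(HF_ID)               # HF config/structure only
processor = AutoProcessor.from_pretrained(HF_ID)            # image+text preprocessing

provider = bridge.to_megatron_provider(load_weights=False)  # weights come from the checkpoint
provider.tensor_model_parallel_size = 8                     # illustrative TP size
provider.pipeline_dtype = torch.bfloat16
provider.finalize()
provider.initialize_model_parallel(seed=0)

model = bridge.load_megatron_model(
    "./qwen3_vl_quantized",                                 # produced by quantize_vlm.py
    mp_overrides={"tensor_model_parallel_size": 8},
    wrap_with_ddp=False,
)
unwrapped = unwrap_model([m.cuda() for m in model])[0]
unwrapped.eval()
# ...then validate the quantized layers and run the VLM forward loop for generation.
```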
