
Commit 0ccb1d6

Fix according to coderabbitai
Signed-off-by: James Shen <[email protected]>
1 parent cf8cdf0 commit 0ccb1d6

10 files changed (+203, -103 lines)


examples/quantization/ptq_generate_vlm.py

Lines changed: 17 additions & 32 deletions
@@ -80,28 +80,12 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
     """
     model_str = str(model)
 
-    # DEBUG: Print full model structure to diagnose CI vs local differences
-    if is_rank_0:
-        console.print(f"\n{'=' * 80}")
-        console.print("[yellow]DEBUG: Full model structure:[/yellow]")
-        console.print(f"{'=' * 80}")
-        console.print(model_str)
-        console.print(f"{'=' * 80}\n")
-
     # TE spec quantized layers (VLM models always use TE spec)
     te_spec_layers = [
        "QuantTERowParallelLinear",
        "QuantTELayerNormColumnParallelLinear",
     ]
 
-    # DEBUG: Check each layer individually
-    if is_rank_0:
-        console.print("[yellow]DEBUG: Checking for quantized layers:[/yellow]")
-        for layer in te_spec_layers:
-            found = layer in model_str
-            status = "[green]FOUND[/green]" if found else "[red]NOT FOUND[/red]"
-            console.print(f" {layer}: {status}")
-
     # Check if model has TE spec quantized layers
     has_te_spec = all(layer in model_str for layer in te_spec_layers)
 
@@ -264,21 +248,22 @@ def main(
         default=DEFAULT_IMAGE_PATH,
         help="Path to the image file for VLM generation.",
     )
-    parser.add_argument("--trust-remote-code", action="store_true", default=True, help="if trust_remote_code")
+    parser.add_argument("--trust-remote-code", action="store_true", help="if trust_remote_code")
 
     args = parser.parse_args()
-    main(
-        args.hf_model_id,
-        args.tp,
-        args.pp,
-        args.ep,
-        args.etp,
-        args.megatron_load_path,
-        args.prompts,
-        args.osl,
-        args.image_path,
-        args.trust_remote_code,
-    )
-
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    try:
+        main(
+            args.hf_model_id,
+            args.tp,
+            args.pp,
+            args.ep,
+            args.etp,
+            args.megatron_load_path,
+            args.prompts,
+            args.osl,
+            args.image_path,
+            args.trust_remote_code,
+        )
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
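
The --trust-remote-code change above is worth a second look: with argparse, action="store_true" already implies a default of False, so combining it with default=True made the flag a no-op that could never be turned off. A minimal standalone sketch of the behavior (plain argparse, independent of the example script):

import argparse

# Old pattern: the option is True whether or not the flag is passed.
parser_old = argparse.ArgumentParser()
parser_old.add_argument("--trust-remote-code", action="store_true", default=True)
print(parser_old.parse_args([]).trust_remote_code)                       # True
print(parser_old.parse_args(["--trust-remote-code"]).trust_remote_code)  # True

# Fixed pattern: False unless the flag is explicitly passed.
parser_new = argparse.ArgumentParser()
parser_new.add_argument("--trust-remote-code", action="store_true")
print(parser_new.parse_args([]).trust_remote_code)                       # False
print(parser_new.parse_args(["--trust-remote-code"]).trust_remote_code)  # True

The try/finally around main() in the same hunk serves a related robustness goal: the process group is destroyed even if generation raises.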

examples/quantization/quantize_utils.py

Lines changed: 6 additions & 2 deletions
@@ -20,6 +20,7 @@
 """
 
 import argparse
+import copy
 
 import modelopt.torch.quantization as mtq
 from rich.console import Console
@@ -40,7 +41,9 @@
 }
 
 
-def get_modelopt_torch_quantization_config(export_quant_cfg, export_kv_cache_quant=False, weight_only=False):
+def get_modelopt_torch_quantization_config(
+    export_quant_cfg: str, export_kv_cache_quant: bool = False, weight_only: bool = False
+) -> dict:
     """Return a quantization config based on the specified configuration.
 
     Args:
@@ -54,7 +57,8 @@ def get_modelopt_torch_quantization_config(export_quant_cfg, export_kv_cache_qua
     Raises:
         KeyError: If export_quant_cfg is not a valid configuration name.
     """
-    mtq_config = QUANT_CFG_CHOICES[export_quant_cfg]
+    # Use deepcopy to avoid mutating the original config in QUANT_CFG_CHOICES
+    mtq_config = copy.deepcopy(QUANT_CFG_CHOICES[export_quant_cfg])
 
     fp8_config = {"enable": True, "num_bits": (4, 3), "axis": None}
     fp4_config = {
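
The copy.deepcopy change above prevents a subtle shared-state bug: the entries in QUANT_CFG_CHOICES are nested dicts, so returning one by reference and then editing it (for example when applying the KV-cache or weight-only options) would silently rewrite the module-level table for every later caller. A small generic sketch of the aliasing problem, using placeholder dicts rather than the real modelopt configs:

import copy

# Hypothetical stand-in for a module-level table of shared quantization configs.
CFG_TABLE = {"fp8": {"quant_cfg": {"*weight_quantizer": {"enable": True}}}}

def get_cfg_by_reference(name):
    cfg = CFG_TABLE[name]  # aliases the shared dict
    cfg["quant_cfg"]["*weight_quantizer"]["enable"] = False
    return cfg

def get_cfg_by_copy(name):
    cfg = copy.deepcopy(CFG_TABLE[name])  # private copy, safe to mutate
    cfg["quant_cfg"]["*weight_quantizer"]["enable"] = False
    return cfg

get_cfg_by_reference("fp8")
print(CFG_TABLE["fp8"]["quant_cfg"]["*weight_quantizer"]["enable"])  # False: the shared table was mutated

CFG_TABLE["fp8"]["quant_cfg"]["*weight_quantizer"]["enable"] = True  # reset for the second demo
get_cfg_by_copy("fp8")
print(CFG_TABLE["fp8"]["quant_cfg"]["*weight_quantizer"]["enable"])  # True: the shared table is untouched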

examples/quantization/quantize_vlm.py

Lines changed: 34 additions & 37 deletions
@@ -209,10 +209,11 @@ def _hf_dataset_forward_loop_func(
         disable_tqdm=True,
     )
 
-    if force_all_expert_routing:
-        for name, module in model.named_modules():
-            if isinstance(module, TopKRouter):
-                module.topk = module.config.moe_router_topk
+    # Restore original topk after calibration is complete
+    if force_all_expert_routing:
+        for name, module in model.named_modules():
+            if isinstance(module, TopKRouter):
+                module.topk = module.config.moe_router_topk
 
 
 def _custom_prompt_forward_loop_func(
@@ -393,6 +394,10 @@ def ptq_forward_loop_func(model):
     if megatron_save_path is None:
         model_name = hf_model_id.replace("/", "_")
         megatron_save_path = f"./{model_name}_quantized_{export_quant_cfg}"
+        if is_rank_0:
+            console.print(
+                f"[yellow]No --megatron-save-path specified. Using default path: {megatron_save_path}[/yellow]"
+            )
 
     if is_rank_0:
         console.print("[green]Testing model AFTER quantization...[/green]")
@@ -406,18 +411,9 @@ def ptq_forward_loop_func(model):
     _custom_prompt_forward_loop_func(unwrapped_model, processor, is_rank_0, prompts)
 
     # Save quantized model in Megatron format
-    if megatron_save_path:
-        save_path = megatron_save_path
-    else:
-        # Create default save path using model name and quantization config
-        model_name = hf_model_id.split("/")[-1]
-        save_path = f"{model_name}_quantized_{export_quant_cfg}"
-        if is_rank_0:
-            console.print(f"[yellow]No --megatron-save-path specified. Using default path: {save_path}[/yellow]")
-
     if is_rank_0:
-        console.print(f"Saving quantized Megatron checkpoint in {save_path}...")
-        bridge.save_megatron_model(megatron_model, save_path)
+        console.print(f"Saving quantized Megatron checkpoint in {megatron_save_path}...")
+        bridge.save_megatron_model(megatron_model, megatron_save_path)
 
 
 if __name__ == "__main__":
@@ -459,25 +455,26 @@ def ptq_forward_loop_func(model):
         "Useful for offline CI environments.",
     )
     args = parser.parse_args()
-    main(
-        args.hf_model_id,
-        args.tp,
-        args.pp,
-        args.ep,
-        args.etp,
-        args.megatron_save_path,
-        args.export_quant_cfg,
-        args.calib_size,
-        args.compress,
-        args.weight_only,
-        args.export_kv_cache_quant,
-        args.force_all_expert_routing,
-        args.trust_remote_code,
-        args.prompts,
-        args.skip_quantization,
-        args.test_image_path,
-        args.use_random_calib,
-    )
-
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    try:
+        main(
+            args.hf_model_id,
+            args.tp,
+            args.pp,
+            args.ep,
+            args.etp,
+            args.megatron_save_path,
+            args.export_quant_cfg,
+            args.calib_size,
+            args.compress,
+            args.weight_only,
+            args.export_kv_cache_quant,
+            args.force_all_expert_routing,
+            args.trust_remote_code,
+            args.prompts,
+            args.skip_quantization,
+            args.test_image_path,
+            args.use_random_calib,
+        )
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
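
The re-indented block in the first hunk only restores each TopKRouter's topk once calibration finishes; the corresponding setup (forcing every expert to receive tokens so all expert weights get calibrated) happens earlier in the script and is not shown in this diff. A hedged sketch of the overall pattern as a context manager; topk and config.moe_router_topk come from the diff, while num_experts and the router discovery by class name are illustrative assumptions:

from contextlib import contextmanager

@contextmanager
def force_all_expert_routing_ctx(model, num_experts):
    """Temporarily route every token to all experts, then restore the configured top-k."""
    # Illustrative: find router modules by class name instead of importing TopKRouter.
    routers = [m for m in model.modules() if type(m).__name__ == "TopKRouter"]
    saved_topk = [r.topk for r in routers]
    try:
        for r in routers:
            r.topk = num_experts  # assumption: route to all experts during calibration
        yield model
    finally:
        for r, topk in zip(routers, saved_topk):
            r.topk = topk  # restore the original top-k, as the diff does via config.moe_router_topk

Wrapped around the calibration forward loop, the finally branch plays the same role as the restoration block added in the hunk above: the routers return to their configured top-k even if calibration fails partway through.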

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ override-dependencies = [
 transformer-engine = { git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "6a34b6574fa6c29d9d07fdcddf9812cbb1488878" }
 megatron-core = { path = "3rdparty/Megatron-LM/" }
 nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" }
-nvidia-modelopt = { git = "https://github.com/NVIDIA/TensorRT-Model-Optimizer.git", rev = "0a4f0a8b933121f7af080261a0a5a7717f2c5d49" }
+nvidia-modelopt = { git = "https://github.com/NVIDIA/TensorRT-Model-Optimizer.git", rev = "aafd3883942a564f1ac08a1e5f363abd61d383cf" }
 # mamba-ssm = { git = "https://github.com/yfw/mamba", branch = "general_stride_fix" }
 
 [project.optional-dependencies]

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py

Lines changed: 3 additions & 2 deletions
@@ -189,8 +189,8 @@ def forward(
         video_grid_thw: torch.Tensor = None,
         # cat set at dataset
         image_input_mask: torch.Tensor = None,
-        inference_context=None,
-        runtime_gather_output=None,
+        inference_context: object | None = None,
+        runtime_gather_output: bool | None = None,
     ) -> torch.Tensor:
         """Forward function of the Qwen3VL model.
 
@@ -212,6 +212,7 @@ def forward(
             output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape
                 [b, s, vocab_size].
         """
+        del inference_context, runtime_gather_output  # Unused, kept for API compatibility
         assert pixel_values_videos is None and video_grid_thw is None, "not support video now"
         assert inference_params is None, "not support inference"
 
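
The added del statement is a small but deliberate idiom: the two parameters exist only so the forward signature matches what callers pass, and deleting them up front documents that they are intentionally ignored while keeping unused-argument linters quiet. A generic sketch of the idiom, unrelated to the Qwen3VL model itself:

def forward_compatible(x: float, inference_context: object | None = None, runtime_gather_output: bool | None = None) -> float:
    # Accepted only for interface compatibility with callers that pass them; not used here.
    del inference_context, runtime_gather_output
    return x * 2.0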

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py

Lines changed: 10 additions & 6 deletions
@@ -111,8 +111,11 @@ def get_rope_index(
             attention_mask = torch.ones_like(total_input_ids)
         # Handle multi-dimensional attention masks
         elif attention_mask.dim() > 2:
-            # For causal mask, create a simple 2D mask [batch, seq]
-            attention_mask = torch.ones_like(total_input_ids)
+            # Collapse to [batch, seq] while preserving padding information
+            attention_mask = attention_mask.any(dim=-1)
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.squeeze(1)
+            attention_mask = attention_mask.to(dtype=total_input_ids.dtype)
         position_ids = torch.ones(
             3,
             input_ids.shape[0],
@@ -192,10 +195,11 @@ def get_rope_index(
         if attention_mask is not None:
             # Handle multi-dimensional attention mask
             if attention_mask.dim() > 2:
-                # For causal mask, create a simple 2D mask [batch, seq]
-                attention_mask = torch.ones(
-                    (input_ids.shape[0], input_ids.shape[1]), dtype=torch.long, device=input_ids.device
-                )
+                # Collapse to [batch, seq] while preserving padding information
+                attention_mask = attention_mask.any(dim=-1)
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.squeeze(1)
+                attention_mask = attention_mask.to(dtype=torch.long)
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
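
The behavioral point of the new mask handling is that attention_mask.any(dim=-1) keeps padding information that the old torch.ones_like / torch.ones fallback threw away: a position stays marked as valid only if its row in the expanded mask attends to at least one key. A small self-contained PyTorch check of the collapse on a boolean 4D mask; the toy construction below (causal mask combined with a padding mask on both queries and keys) is an assumption for illustration, in the same 0/1 style as the masks used in the new unit tests:

import torch

# Toy setup: batch of 1, sequence of 5, last two positions are padding.
pad = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)   # [batch, seq] padding mask
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))   # [seq, seq] causal mask

# Combined 4D mask [batch, 1, seq_q, seq_k] with padding applied to queries and keys.
mask_4d = causal[None, None] & pad[:, None, :, None] & pad[:, None, None, :]

# Collapse back to [batch, seq] the same way the updated get_rope_index does.
collapsed = mask_4d.any(dim=-1)   # [batch, 1, seq]
collapsed = collapsed.squeeze(1)  # [batch, seq]

print(collapsed.long())             # tensor([[1, 1, 1, 0, 0]]) -- padding preserved
print(torch.equal(collapsed, pad))  # True

# The old fallback produced all ones, treating padded positions as real tokens.
print(torch.ones_like(pad).long())  # tensor([[1, 1, 1, 1, 1]])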

tests/functional_tests/quantization/models/qwen_vl/test_qwen3_vl_quantization_workflow.py

Lines changed: 4 additions & 4 deletions
@@ -326,7 +326,7 @@ def test_qwen3_vl_quantization_and_generation(self, qwen3_vl_toy_model_path, tmp
         if quantize_result.returncode != 0:
             print(f"Quantization STDOUT: {quantize_result.stdout}")
             print(f"Quantization STDERR: {quantize_result.stderr}")
-            assert False, f"Quantization step failed with return code {quantize_result.returncode}"
+            pytest.fail(f"Quantization step failed with return code {quantize_result.returncode}")
 
         # Verify quantization succeeded
         assert "Quantizing the model with fp8 configuration" in quantize_result.stdout, (
@@ -350,7 +350,7 @@ def test_qwen3_vl_quantization_and_generation(self, qwen3_vl_toy_model_path, tmp
         if generation_result.returncode != 0:
             print(f"Generation STDOUT: {generation_result.stdout}")
             print(f"Generation STDERR: {generation_result.stderr}")
-            assert False, f"Generation step failed with return code {generation_result.returncode}"
+            pytest.fail(f"Generation step failed with return code {generation_result.returncode}")
 
         # Verify generation succeeded
         stdout_normalized = generation_result.stdout.replace("\n", "")
@@ -408,7 +408,7 @@ def test_qwen3_vl_quantization_and_generation_parallelism(
         if quantize_result.returncode != 0:
             print(f"Quantization STDOUT: {quantize_result.stdout}")
             print(f"Quantization STDERR: {quantize_result.stderr}")
-            assert False, f"Quantization step for {test_name} failed with return code {quantize_result.returncode}"
+            pytest.fail(f"Quantization step for {test_name} failed with return code {quantize_result.returncode}")
 
         # Verify quantization succeeded with correct parallelism
         assert "Quantizing the model with fp8 configuration" in quantize_result.stdout, (
@@ -438,7 +438,7 @@ def test_qwen3_vl_quantization_and_generation_parallelism(
         if generation_result.returncode != 0:
             print(f"Generation STDOUT: {generation_result.stdout}")
             print(f"Generation STDERR: {generation_result.stderr}")
-            assert False, f"Generation step for {test_name} failed with return code {generation_result.returncode}"
+            pytest.fail(f"Generation step for {test_name} failed with return code {generation_result.returncode}")
 
         # Verify generation succeeded with correct parallelism
         stdout_normalized = generation_result.stdout.replace("\n", "")
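
Replacing assert False with pytest.fail is more than a style preference: a bare assert is stripped when the interpreter runs with -O (assertions disabled), so the guarded failure branch would silently pass, and pytest.fail also states the intent explicitly. A tiny generic sketch of the pattern with a hypothetical subprocess call, not the actual workflow test:

import subprocess

import pytest

def test_tool_exits_cleanly():
    result = subprocess.run(["true"], capture_output=True, text=True)
    if result.returncode != 0:
        print(f"STDOUT: {result.stdout}")
        print(f"STDERR: {result.stderr}")
        # pytest.fail always raises, even under `python -O`, unlike a bare assert False.
        pytest.fail(f"Tool failed with return code {result.returncode}")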

tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_utils.py

Lines changed: 62 additions & 0 deletions
@@ -144,3 +144,65 @@ def test_get_rope_index_packed_seq_params_fallback_dense_mask(self):
 
         assert torch.equal(position_ids, expected_positions)
         assert torch.equal(deltas, expected_deltas)
+
+    def test_get_rope_index_with_3d_attention_mask(self):
+        """Test get_rope_index with 3D attention mask (batch, seq, seq)."""
+        batch_size, seq_len = 2, 8
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Create a 3D causal attention mask [batch, seq, seq]
+        attention_mask = torch.tril(torch.ones((batch_size, seq_len, seq_len)))
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            vision_start_token_id=151652,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_4d_attention_mask(self):
+        """Test get_rope_index with 4D attention mask (batch, 1, seq, seq)."""
+        batch_size, seq_len = 2, 8
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Create a 4D attention mask [batch, 1, seq, seq] - singleton head dimension
+        attention_mask = torch.tril(torch.ones((batch_size, 1, seq_len, seq_len)))
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            vision_start_token_id=151652,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
+
+    def test_get_rope_index_with_3d_attention_mask_and_image(self):
+        """Test get_rope_index with 3D attention mask and image grid."""
+        batch_size, seq_len = 1, 16
+        input_ids = torch.randint(0, 1000, (batch_size, seq_len))
+        # Insert vision tokens
+        input_ids[0, 4] = 151652  # vision_start_token_id
+        input_ids[0, 5] = 151655  # image_token_id
+        image_grid_thw = torch.tensor([[1, 4, 4]])  # t=1, h=4, w=4
+        # Create a 3D attention mask [batch, seq, seq]
+        attention_mask = torch.tril(torch.ones((batch_size, seq_len, seq_len)))
+
+        position_ids, deltas = get_rope_index(
+            spatial_merge_size=2,
+            image_token_id=151655,
+            video_token_id=151656,
+            vision_start_token_id=151652,
+            input_ids=input_ids,
+            image_grid_thw=image_grid_thw,
+            attention_mask=attention_mask,
+        )
+
+        assert position_ids.shape == (3, batch_size, seq_len)
+        assert deltas.shape == (batch_size, 1)
