Commit 24a4dfd

special handling for nemotron VL preview generation in hf_ptq

Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent f4134e3 commit 24a4dfd

1 file changed: examples/llm_ptq/hf_ptq.py (+208 additions, -23 deletions)
@@ -33,6 +33,7 @@
 )
 from transformers import (
     AutoConfig,
+    AutoImageProcessor,
     AutoModelForCausalLM,
     AutoProcessor,
     PreTrainedTokenizer,
@@ -91,6 +92,86 @@
 mto.enable_huggingface_checkpointing()


+def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
+    """Run preview generation for VL models using sample images.
+
+    Args:
+        model: The VL model
+        tokenizer: The tokenizer
+        model_path: Path to the model (for loading image processor)
+        stage_name: Description of the stage (e.g., "before quantization")
+
+    Returns:
+        Generated response text for logging/comparison
+    """
+    import os
+
+    from PIL import Image
+    from transformers import AutoImageProcessor
+
+    try:
+        print(f"Loading sample images for {stage_name} preview...")
+
+        # Load image processor
+        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+        # Load sample images from the images directory
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        images_dir = os.path.join(script_dir, "images")
+
+        image_files = ["example1a.jpeg", "example1b.jpeg"]
+        images = []
+        for img_file in image_files:
+            img_path = os.path.join(images_dir, img_file)
+            if os.path.exists(img_path):
+                images.append(Image.open(img_path))
+                print(f"  Loaded: {img_file}")
+            else:
+                print(f"  Warning: {img_file} not found")
+
+        if not images:
+            print("No sample images found - skipping VL preview generation")
+            return None
+
+        # Process images
+        image_features = image_processor(images)
+
+        # Move image features to the same device as the model
+        model_device = model.device
+        for key, value in image_features.items():
+            if hasattr(value, "to"):  # Check if it's a tensor
+                image_features[key] = value.to(model_device)
+                print(f"  Moved {key} to {model_device}")
+
+        # Generate response
+        question = "Describe these images briefly."
+        generation_config = {
+            "max_new_tokens": 50,
+            "do_sample": False,
+            "eos_token_id": tokenizer.eos_token_id,
+        }
+
+        print(f"Generating VL response ({stage_name})...")
+        response = model.chat(
+            tokenizer=tokenizer,
+            question=question,
+            generation_config=generation_config,
+            **image_features,
+        )
+
+        print(f"✅ VL generation {stage_name} successful!")
+        print(f"Question: {question}")
+        print(f"Response: {response}")
+
+        # Return the response for comparison/logging
+        return response
+
+    except Exception as e:
+        print(f"❌ VL preview generation {stage_name} failed: {e}")
+        print("This may indicate issues with the quantized model")
+        return None
+
+
 def auto_quantize(
     model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
 ):
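Usage note (not part of the commit): a minimal sketch of how this helper could be exercised on its own, assuming a Nemotron VL checkpoint whose composite model and custom chat() method load via trust_remote_code, as hf_ptq.py's own loading path suggests. The checkpoint path is a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "/path/to/nemotron-vl-checkpoint"  # placeholder, not from the commit
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)
# Sample images are expected under examples/llm_ptq/images/ next to hf_ptq.py.
_run_vl_preview_generation(model, tokenizer, ckpt, "standalone smoke test")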
@@ -486,13 +567,45 @@ def main(args):
             "input_features" if model_type == "whisper" else "input_ids"
         ][0:1]

-        # Skip preview generation for Nemotron VL models that require special handling
+        # For Nemotron VL models, try text-only generation first, then VL generation as additional test
         is_nemotron_vl = (
             "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
         )
         if is_nemotron_vl:
-            print("Skipping preview generation for Nemotron VL model (requires image input)")
-            generated_ids_before_ptq = None
+            print("Running text-only preview generation for Nemotron VL model...")
+            try:
+                # Try text-only generation using model.chat with None for images
+                question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+                generation_config = {
+                    "max_new_tokens": 100,
+                    "do_sample": False,
+                    "eos_token_id": tokenizer.eos_token_id,
+                }
+
+                # Use model.chat with None for images (text-only mode)
+                text_response = full_model.chat(
+                    tokenizer, None, question, generation_config, history=None
+                )
+                generated_ids_before_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            except Exception as e:
+                print(f"Text-only generation failed: {e}")
+                print("Falling back to standard generate() method...")
+                try:
+                    generated_ids_before_ptq = full_model.generate(
+                        input_ids, max_new_tokens=100
+                    )
+                except Exception as e2:
+                    print(f"Standard generation also failed: {e2}")
+                    generated_ids_before_ptq = None
+
+            # Run additional VL test with images
+            print("Running additional VL test with images...")
+            _run_vl_preview_generation(
+                full_model, tokenizer, args.pyt_ckpt_path, "before quantization (VL test)"
+            )
+
         else:
             try:
                 generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)
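For orientation: the positional call above passes None where the helper's VL call passes image tensors, which implies a chat() signature roughly like the stub below. This is inferred from the two call sites only; the real implementation ships with the checkpoint's remote code and may differ in detail.

class NemotronVLChatAPI:  # hypothetical stub, not from the commit
    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, **kwargs):
        """pixel_values=None (the text-only path above) skips the vision
        encoder; image features passed as keyword arguments enable VL mode."""
        raise NotImplementedError  # real implementation lives in the checkpoint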
@@ -508,6 +621,11 @@ def main(args):

         # quantize the model
         model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
+
+        # For VL models, update full_model to use the quantized language model
+        if is_nemotron_vl and hasattr(full_model, "language_model"):
+            print("Updating full_model with quantized language_model...")
+            full_model.language_model = model
         if args.verbose:
             mtq.print_quant_summary(model)
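An aside on the rewiring step above, sketched under the assumption that quantize_model() ultimately calls ModelOpt's mtq.quantize() on the text backbone only (names follow the diff): the parent VL wrapper must be pointed back at the quantized submodule so that end-to-end generate()/chat() calls exercise the quantized weights.

import modelopt.torch.quantization as mtq

# Quantize only the text backbone of the composite VL model...
language_model = mtq.quantize(full_model.language_model, quant_cfg, forward_loop=calibrate_loop)
# ...then rewire the wrapper so full_model.generate()/chat() see it.
full_model.language_model = language_model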

@@ -518,9 +636,33 @@
             # Our fake quantizer may not be fully compatible with torch.compile.
             generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
         elif is_nemotron_vl:
-            print(
-                "Skipping post-quantization generation for Nemotron VL model (requires image input)"
+            print("Running text-only preview generation for quantized Nemotron VL model...")
+            try:
+                # Try text-only generation using model.chat with None for images
+                question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+                generation_config = {
+                    "max_new_tokens": 100,
+                    "do_sample": False,
+                    "eos_token_id": tokenizer.eos_token_id,
+                }
+
+                # Use model.chat with None for images (text-only mode)
+                text_response = full_model.chat(
+                    tokenizer, None, question, generation_config, history=None
+                )
+                generated_ids_after_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            except Exception as e:
+                print(f"Text-only generation failed: {e}")
+                generated_ids_after_ptq = None
+
+            # Run additional VL test with images
+            print("Running additional VL test with images...")
+            _run_vl_preview_generation(
+                full_model, tokenizer, args.pyt_ckpt_path, "after quantization (VL test)"
             )
+
         else:
             warnings.warn(
                 "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
@@ -553,15 +695,25 @@ def output_decode(generated_ids, input_shape):

         if generated_ids_after_ptq is not None:
             print("--------")
-            print(f"example test input: {input_decode(input_ids)}")
-            print("--------")
-            print(
-                f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
-            )
-            print("--------")
-            print(
-                f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
-            )
+            if is_nemotron_vl:
+                # For Nemotron VL models, generated_ids are text strings from model.chat()
+                print("Nemotron VL model text-only generation results:")
+                print(f"Text response before quantization: {generated_ids_before_ptq}")
+                print("--------")
+                print(f"Text response after quantization: {generated_ids_after_ptq}")
+                print("--------")
+                print("Note: Additional VL tests with images were run separately above")
+            else:
+                # For regular LLMs, generated_ids are token tensors that need decoding
+                print(f"example test input: {input_decode(input_ids)}")
+                print("--------")
+                print(
+                    f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}"
+                )
+                print("--------")
+                print(
+                    f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}"
+                )
     else:
         warnings.warn("Skipping quantization: model is already quantized.")

@@ -590,15 +742,48 @@ def output_decode(generated_ids, input_shape):
                 export_path
             )

-        # Try to save processor config if available
-        try:
-            print(f"Saving processor config to {export_path}")
-            AutoProcessor.from_pretrained(
-                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-            ).save_pretrained(export_path)
-        except Exception as e:
-            print(f"Warning: Could not save processor config: {e}")
-            print("This is normal for some VLM architectures that don't use AutoProcessor")
+        # Try to save processor config if available (skip for Nemotron VL models)
+        if not is_nemotron_vl:
+            try:
+                print(f"Saving processor config to {export_path}")
+                AutoProcessor.from_pretrained(
+                    args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+                ).save_pretrained(export_path)
+            except Exception as e:
+                print(f"Warning: Could not save processor config: {e}")
+                print("This is normal for some VLM architectures that don't use AutoProcessor")
+        else:
+            print("Skipping AutoProcessor for Nemotron VL (uses separate AutoImageProcessor)")
+
+        # For Nemotron VL models, save image processor using proper HuggingFace APIs
+        if is_nemotron_vl:
+            import os
+            import shutil
+
+            # Try to save image processor config using HuggingFace API
+            try:
+                print("Saving image processor config using AutoImageProcessor...")
+                image_processor = AutoImageProcessor.from_pretrained(
+                    args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+                )
+                image_processor.save_pretrained(export_path)
+                print("  ✅ Image processor config saved successfully")
+            except Exception as e:
+                print(f"  Warning: Could not save image processor config: {e}")
+
+            # Manually copy image_processing.py as it contains custom code that save_pretrained doesn't handle
+            print("Copying custom image processing implementation...")
+            src_path = os.path.join(args.pyt_ckpt_path, "image_processing.py")
+            dst_path = os.path.join(export_path, "image_processing.py")
+
+            if os.path.exists(src_path):
+                try:
+                    shutil.copy2(src_path, dst_path)
+                    print("  ✅ Copied: image_processing.py")
+                except Exception as copy_e:
+                    print(f"  Warning: Could not copy image_processing.py: {copy_e}")
+            else:
+                print("  Warning: image_processing.py not found in source model")

         if model_type == "mllama":
             full_model_config = model.config
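A quick post-export sanity check (illustrative, not part of the commit): after export, the Nemotron VL artifacts handled above should sit alongside the quantized weights. preprocessor_config.json is the standard file written by an image processor's save_pretrained(); the export path is a placeholder.

import os

export_path = "/path/to/export"  # placeholder
for fname in ("preprocessor_config.json", "image_processing.py"):
    path = os.path.join(export_path, fname)
    print(f"{fname}: {'found' if os.path.exists(path) else 'MISSING'}")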
