Commit 5b3c01b

refactor and create a utils script for vlm
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 8397bd3 commit 5b3c01b

3 files changed: +251 −247 lines changed

examples/llm_ptq/example_utils.py

Lines changed: 0 additions & 1 deletion

@@ -202,7 +202,6 @@ def get_model(
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}

     # Special handling for vision-language models that may have device mapping issues
-    # Check if this is a VL model by examining the config before loading the full model
     try:
         hf_config_check = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
         if _is_multimodal_config(hf_config_check):
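Note: the _is_multimodal_config helper used in this hunk is defined elsewhere in example_utils.py and is not part of this diff. As a rough illustration only, a check of this kind usually inspects the loaded AutoConfig for vision-related fields; the body below is an assumption, not the repository's implementation.

# Hypothetical sketch -- the real _is_multimodal_config lives elsewhere in
# example_utils.py and is not shown in this diff.
def _is_multimodal_config(hf_config) -> bool:
    """Heuristically decide whether a config describes a vision-language model."""
    # A nested vision config, or image/vision-specific fields on the top-level
    # config, are treated as multimodal markers here; the real helper may rely
    # on different signals.
    if getattr(hf_config, "vision_config", None) is not None:
        return True
    return hasattr(hf_config, "image_token_id") or hasattr(hf_config, "vision_tower")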

examples/llm_ptq/hf_ptq.py

Lines changed: 18 additions & 246 deletions
@@ -40,6 +40,7 @@
     PreTrainedTokenizerFast,
     WhisperProcessor,
 )
+from vlm_utils import run_text_only_generation, run_vl_preview_generation

 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
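The commit message says a utils script for VLM was created; that new module is the third changed file, but its diff is not shown on this page. Based on the import added above and the helpers removed in the next hunk, a minimal sketch of the interface examples/llm_ptq/vlm_utils.py presumably exposes (function names come from the import, the docstrings summarize the deleted _run_* helpers, and everything else is an assumption):

# Hypothetical outline of examples/llm_ptq/vlm_utils.py -- reconstructed from the
# import above and the helpers deleted below, not copied from the actual commit.


def run_vl_preview_generation(model, tokenizer, model_path, stage_name):
    """Run preview generation for VL models using sample images.

    Mirrors the removed _run_vl_preview_generation: loads a sample image,
    dispatches to model.chat() (v1 models) or model.generate() via AutoProcessor
    (v2 models), and returns the decoded response text, or None on failure.
    """
    ...


def run_text_only_generation(model, tokenizer, question, generation_config, model_path):
    """Run text-only generation for VL models (v1 chat or v2 generate paths).

    Mirrors the removed _run_text_only_generation and returns the generated
    text, or None if generation fails.
    """
    ...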
@@ -92,225 +93,6 @@
 mto.enable_huggingface_checkpointing()


-def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
-    """Run preview generation for VL models using sample images.
-
-    Args:
-        model: The VL model
-        tokenizer: The tokenizer
-        model_path: Path to the model (for loading image processor)
-        stage_name: Description of the stage (e.g., "before quantization")
-
-    Returns:
-        Generated response text for logging/comparison
-    """
-    import os
-
-    from PIL import Image
-    from transformers import AutoImageProcessor, AutoProcessor
-
-    try:
-        print(f"Loading sample images for {stage_name} preview...")
-
-        # Load sample images from the images directory
-        script_dir = os.path.dirname(os.path.abspath(__file__))
-        images_dir = os.path.join(script_dir, "images")
-
-        # Use single image for VL preview to avoid shape mismatch issues
-        image_files = ["example1a.jpeg", "example1b.jpeg"]
-        image = None
-        for img_file in image_files:
-            img_path = os.path.join(images_dir, img_file)
-            if os.path.exists(img_path):
-                image = Image.open(img_path)
-                print(f" Loaded: {img_file}")
-                break # Use the first available image
-            else:
-                print(f" Warning: {img_file} not found")
-
-        if image is None:
-            print("No sample images found - skipping VL preview generation")
-            return None
-
-        # Generate response
-        question = "Describe this image briefly." # Updated for single image
-        generation_config = {
-            "max_new_tokens": 50,
-            "do_sample": False,
-            "eos_token_id": tokenizer.eos_token_id,
-        }
-
-        print(f"Generating VL response ({stage_name})...")
-
-        # Try to detect if this is a v1 model (has chat method) or v2 model (uses generate)
-        if hasattr(model, "chat"):
-            print(" Using v1 model.chat() method...")
-            # Load image processor for v1 models
-            image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-            # Process single image for v1 models
-            image_features = image_processor([image]) # Pass as list with single image
-
-            # Move image features to the same device as the model
-            model_device = model.device
-            for key, value in image_features.items():
-                if hasattr(value, "to"): # Check if it's a tensor
-                    image_features[key] = value.to(model_device)
-                    print(f" Moved {key} to {model_device}")
-
-            response = model.chat(
-                tokenizer=tokenizer,
-                question=question,
-                generation_config=generation_config,
-                **image_features,
-            )
-        else:
-            print(" Using v2 model.generate() method...")
-            # Load processor for v2 models
-            processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-            # Create messages in the format expected by v2 models
-            messages = [
-                {"role": "system", "content": "/no_think"},
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image",
-                            "image": "",
-                        },
-                        {
-                            "type": "text",
-                            "text": question,
-                        },
-                    ],
-                },
-            ]
-
-            # Apply chat template
-            prompt = tokenizer.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
-
-            # Process inputs using the processor with single image
-            inputs = processor(
-                text=[prompt],
-                images=[image], # Pass single image as list
-                return_tensors="pt",
-            )
-
-            # Move inputs to the same device as the model
-            model_device = model.device
-            inputs = inputs.to(model_device)
-            print(f" Moved inputs to {model_device}")
-
-            # Generate response using model.generate
-            generated_ids = model.generate(
-                pixel_values=inputs.pixel_values,
-                input_ids=inputs.input_ids,
-                attention_mask=inputs.attention_mask,
-                **generation_config,
-            )
-
-            # Decode the response (trim input tokens like in the working example)
-            generated_ids_trimmed = [
-                out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-            ]
-            output_text = processor.batch_decode(
-                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-            )
-            response = output_text[0]
-
-        print(f"✅ VL generation {stage_name} successful!")
-        print(f"Question: {question}")
-        print(f"Response: {response}")
-
-        # Return the response for comparison/logging
-        return response
-
-    except Exception as e:
-        print(f"❌ VL preview generation {stage_name} failed: {e}")
-        print("This may indicate issues with the quantized model")
-        return None
-
-
-def _run_text_only_generation(model, tokenizer, question, generation_config, model_path):
-    """Run text-only generation for VL models, supporting both v1 (chat) and v2 (generate) models.
-
-    Args:
-        model: The VL model
-        tokenizer: The tokenizer
-        question: The text question to ask
-        generation_config: Generation configuration
-        model_path: Path to the model (for loading processor if needed)
-
-    Returns:
-        Generated response text or None if failed
-    """
-    try:
-        if hasattr(model, "chat"):
-            print(" Using v1 model.chat() method for text-only generation...")
-            # Use model.chat with None for images (text-only mode)
-            response = model.chat(tokenizer, None, question, generation_config, history=None)
-            return response
-        else:
-            print(" Using v2 model.generate() method for text-only generation...")
-            # Load processor for v2 models
-            from transformers import AutoProcessor
-
-            processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-            # Create text-only messages
-            messages = [
-                {"role": "system", "content": "/no_think"},
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": question,
-                        },
-                    ],
-                },
-            ]
-
-            # Apply chat template
-            prompt = tokenizer.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
-
-            # Process text-only inputs
-            inputs = processor(
-                text=[prompt],
-                images=None, # No images for text-only
-                return_tensors="pt",
-            )
-
-            # Move inputs to the same device as the model
-            model_device = model.device
-            inputs = inputs.to(model_device)
-
-            # Generate response using model.generate
-            generated_ids = model.generate(
-                input_ids=inputs.input_ids,
-                attention_mask=inputs.attention_mask,
-                **generation_config,
-            )
-
-            # Decode the response (trim input tokens like in the working example)
-            generated_ids_trimmed = [
-                out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-            ]
-            output_text = processor.batch_decode(
-                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-            )
-            return output_text[0]
-
-    except Exception as e:
-        print(f"Text-only generation failed: {e}")
-        return None
-
-
 def auto_quantize(
     model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
 ):
@@ -688,17 +470,17 @@ def main(args):
         KV_QUANT_CFG_CHOICES,
     )

-    # For Nemotron VL models, disable quantization of vision components
-    is_nemotron_vl = (
-        "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
-    )
-    if is_nemotron_vl:
-        print("Disabling quantization for vision components in Nemotron VL model")
-        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
-        # Also disable radio model components specifically
-        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+    # For Nemotron VL models, disable quantization of vision components
+    is_nemotron_vl = (
+        "nemotron" in args.pyt_ckpt_path.lower() and "vl" in args.pyt_ckpt_path.lower()
+    )
+    if is_nemotron_vl:
+        print("Disabling quantization for vision components in Nemotron VL model")
+        quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+        # Also disable radio model components specifically
+        quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}

     if not model_is_already_quantized or calibration_only:
         # Only run single sample for preview
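The keys added in this hunk ("*vision*", "*image*", "*radio*", "*visual*") are shell-style wildcard patterns matched against module/quantizer names so that the vision stack of a Nemotron VL checkpoint is left unquantized. A small stand-alone illustration of that pattern-matching idea, using fnmatch on made-up module names (the names and helper below are hypothetical, not taken from a real checkpoint):

# Illustrative only -- shows how wildcard keys like "*vision*" select modules by
# name; the module names below are invented for the example.
from fnmatch import fnmatch

disabled_patterns = ["*vision*", "*image*", "*radio*", "*visual*"]

module_names = [
    "language_model.layers.0.mlp.fc1",
    "vision_model.encoder.layers.0.attention.qkv",
    "radio_model.blocks.3.mlp.fc2",
]

for name in module_names:
    skip_quant = any(fnmatch(name, pattern) for pattern in disabled_patterns)
    print(f"{name}: {'skip quantization' if skip_quant else 'quantize'}")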
@@ -725,7 +507,7 @@ def main(args):
         }

         # Use helper function that supports both v1 and v2 models
-        text_response = _run_text_only_generation(
+        text_response = run_text_only_generation(
            full_model, tokenizer, question, generation_config, args.pyt_ckpt_path
         )

@@ -748,7 +530,7 @@ def main(args):

         # Run additional VL test with images
         print("Running additional VL test with images...")
-        _run_vl_preview_generation(
+        run_vl_preview_generation(
            full_model, tokenizer, args.pyt_ckpt_path, "before quantization (VL test)"
         )

@@ -768,23 +550,13 @@ def main(args):
         # quantize the model
         model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)

-        # amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
-
-
         # For VL models, update full_model to use the quantized language model
         if is_nemotron_vl and hasattr(full_model, "language_model"):
             print("Updating full_model with quantized language_model...")
             full_model.language_model = model
-            amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
-            model_keys = full_model.load_state_dict(amax_state_dict, strict=False)
-            print(f"Loaded amax_state_dict with keys: {model_keys}")
-            # fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
-            # print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
-            mtq.print_quant_summary(full_model.language_model)
-            print("Loaded additional state dict into full_model.")
+
         if args.verbose:
-            pass
-            # mtq.print_quant_summary(model)
+            mtq.print_quant_summary(model)

         # Run some samples
         torch.cuda.empty_cache()
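The lines deleted in this hunk were debugging code that loaded a hard-coded, developer-specific scaler state dict from a scratch path; the refactor restores the plain mtq.print_quant_summary(model) call under args.verbose. If quantizer state needs to be persisted and restored reproducibly, something along the lines of modelopt's save/restore utilities could be used instead; the sketch below is an assumption about that flow and is not part of this commit (the path is a placeholder):

# Hedged sketch (not from this commit): persist and later re-apply the quantized
# model's modelopt state instead of loading a hard-coded state dict.
import modelopt.torch.opt as mto


def save_and_restore_example(model, path="quantized_modelopt_state.pth"):
    # Persist the modelopt (quantizer) state after calibration/quantization...
    mto.save(model, path)
    # ...and re-apply it later to a freshly instantiated base model.
    mto.restore(model, path)
    return model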
@@ -807,7 +579,7 @@ def main(args):
         }

         # Use helper function that supports both v1 and v2 models
-        text_response = _run_text_only_generation(
+        text_response = run_text_only_generation(
            full_model, tokenizer, question, generation_config, args.pyt_ckpt_path
         )

@@ -823,7 +595,7 @@ def main(args):

         # Run additional VL test with images
         print("Running additional VL test with images...")
-        _run_vl_preview_generation(
+        run_vl_preview_generation(
            full_model, tokenizer, args.pyt_ckpt_path, "after quantization (VL test)"
         )
