
Commit 659a2c2

add support for v2 model inference (.generate) with image inputs
1 parent 3d6a42f commit 659a2c2

File tree

1 file changed: +183 −36 lines changed


examples/llm_ptq/hf_ptq.py

Lines changed: 183 additions & 36 deletions
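
At a glance, the diff teaches the preview helpers to pick an inference path from the loaded model object: checkpoints that expose a custom chat() method keep using it (v1), while checkpoints without it are driven through the standard Hugging Face generate() API with an AutoProcessor (v2). A minimal sketch of that dispatch, distilled from the hunks below (the helper name describe_image and the simplified messages list are illustrative, not part of this commit):

# Illustrative sketch only: condenses the v1/v2 dispatch added to
# _run_vl_preview_generation; the helper name and simplified messages are not from the commit.
from transformers import AutoImageProcessor, AutoProcessor


def describe_image(model, tokenizer, model_path, image, question):
    generation_config = {
        "max_new_tokens": 50,
        "do_sample": False,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if hasattr(model, "chat"):  # v1 checkpoint: custom chat() that takes image features
        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
        features = image_processor([image])
        features = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in features.items()}
        return model.chat(
            tokenizer=tokenizer,
            question=question,
            generation_config=generation_config,
            **features,
        )
    # v2 checkpoint: standard generate() driven by an AutoProcessor and a chat template
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    messages = [{"role": "user", "content": [{"type": "image", "image": ""},
                                             {"type": "text", "text": question}]}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)
    generated = model.generate(
        pixel_values=inputs.pixel_values,
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config,
    )
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]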
@@ -107,57 +107,119 @@ def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
     import os

     from PIL import Image
-    from transformers import AutoImageProcessor
+    from transformers import AutoImageProcessor, AutoProcessor

     try:
         print(f"Loading sample images for {stage_name} preview...")

-        # Load image processor
-        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
-
         # Load sample images from the images directory
         script_dir = os.path.dirname(os.path.abspath(__file__))
         images_dir = os.path.join(script_dir, "images")

+        # Use single image for VL preview to avoid shape mismatch issues
         image_files = ["example1a.jpeg", "example1b.jpeg"]
-        images = []
+        image = None
         for img_file in image_files:
             img_path = os.path.join(images_dir, img_file)
             if os.path.exists(img_path):
-                images.append(Image.open(img_path))
+                image = Image.open(img_path)
                 print(f" Loaded: {img_file}")
+                break  # Use the first available image
             else:
                 print(f" Warning: {img_file} not found")

-        if not images:
+        if image is None:
             print("No sample images found - skipping VL preview generation")
             return None

-        # Process images
-        image_features = image_processor(images)
-
-        # Move image features to the same device as the model
-        model_device = model.device
-        for key, value in image_features.items():
-            if hasattr(value, "to"):  # Check if it's a tensor
-                image_features[key] = value.to(model_device)
-                print(f" Moved {key} to {model_device}")
-
         # Generate response
-        question = "Describe these images briefly."
+        question = "Describe this image briefly."  # Updated for single image
         generation_config = {
             "max_new_tokens": 50,
             "do_sample": False,
             "eos_token_id": tokenizer.eos_token_id,
         }

         print(f"Generating VL response ({stage_name})...")
-        response = model.chat(
-            tokenizer=tokenizer,
-            question=question,
-            generation_config=generation_config,
-            **image_features,
-        )
+
+        # Try to detect if this is a v1 model (has chat method) or v2 model (uses generate)
+        if hasattr(model, "chat"):
+            print(" Using v1 model.chat() method...")
+            # Load image processor for v1 models
+            image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+            # Process single image for v1 models
+            image_features = image_processor([image])  # Pass as list with single image
+
+            # Move image features to the same device as the model
+            model_device = model.device
+            for key, value in image_features.items():
+                if hasattr(value, "to"):  # Check if it's a tensor
+                    image_features[key] = value.to(model_device)
+                    print(f" Moved {key} to {model_device}")
+
+            response = model.chat(
+                tokenizer=tokenizer,
+                question=question,
+                generation_config=generation_config,
+                **image_features,
+            )
+        else:
+            print(" Using v2 model.generate() method...")
+            # Load processor for v2 models
+            processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+            # Create messages in the format expected by v2 models
+            messages = [
+                {"role": "system", "content": "/no_think"},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",
+                            "image": "",
+                        },
+                        {
+                            "type": "text",
+                            "text": question,
+                        },
+                    ],
+                },
+            ]
+
+            # Apply chat template
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+
+            # Process inputs using the processor with single image
+            inputs = processor(
+                text=[prompt],
+                images=[image],  # Pass single image as list
+                return_tensors="pt",
+            )
+
+            # Move inputs to the same device as the model
+            model_device = model.device
+            inputs = inputs.to(model_device)
+            print(f" Moved inputs to {model_device}")
+
+            # Generate response using model.generate
+            generated_ids = model.generate(
+                pixel_values=inputs.pixel_values,
+                input_ids=inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                **generation_config,
+            )
+
+            # Decode the response (trim input tokens like in the working example)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            ]
+            output_text = processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            response = output_text[0]

         print(f"✅ VL generation {stage_name} successful!")
         print(f"Question: {question}")
@@ -172,6 +234,83 @@ def _run_vl_preview_generation(model, tokenizer, model_path, stage_name):
         return None


+def _run_text_only_generation(model, tokenizer, question, generation_config, model_path):
+    """Run text-only generation for VL models, supporting both v1 (chat) and v2 (generate) models.
+
+    Args:
+        model: The VL model
+        tokenizer: The tokenizer
+        question: The text question to ask
+        generation_config: Generation configuration
+        model_path: Path to the model (for loading processor if needed)
+
+    Returns:
+        Generated response text or None if failed
+    """
+    try:
+        if hasattr(model, "chat"):
+            print(" Using v1 model.chat() method for text-only generation...")
+            # Use model.chat with None for images (text-only mode)
+            response = model.chat(tokenizer, None, question, generation_config, history=None)
+            return response
+        else:
+            print(" Using v2 model.generate() method for text-only generation...")
+            # Load processor for v2 models
+            from transformers import AutoProcessor
+
+            processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+            # Create text-only messages
+            messages = [
+                {"role": "system", "content": "/no_think"},
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": question,
+                        },
+                    ],
+                },
+            ]
+
+            # Apply chat template
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+
+            # Process text-only inputs
+            inputs = processor(
+                text=[prompt],
+                images=None,  # No images for text-only
+                return_tensors="pt",
+            )
+
+            # Move inputs to the same device as the model
+            model_device = model.device
+            inputs = inputs.to(model_device)
+
+            # Generate response using model.generate
+            generated_ids = model.generate(
+                input_ids=inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                **generation_config,
+            )
+
+            # Decode the response (trim input tokens like in the working example)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            ]
+            output_text = processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            return output_text[0]
+
+    except Exception as e:
+        print(f"Text-only generation failed: {e}")
+        return None
+
+
 def auto_quantize(
     model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
 ):
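
The two call sites patched further down invoke this helper the same way before and after quantization; a minimal sketch of that call, using the variable names from main() (the surrounding setup is elided here):

# Sketch of the call sites in main(); question and generation_config are built earlier in main().
text_response = _run_text_only_generation(
    full_model, tokenizer, question, generation_config, args.pyt_ckpt_path
)
if text_response is not None:
    print(f"✅ Text-only generation successful: {text_response[:100]}...")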
@@ -574,7 +713,7 @@ def main(args):
     if is_nemotron_vl:
         print("Running text-only preview generation for Nemotron VL model...")
         try:
-            # Try text-only generation using model.chat with None for images
+            # Try text-only generation using helper function that supports both v1 and v2
             if tokenizer is None:
                 raise ValueError("Tokenizer is required for Nemotron VL text generation")

@@ -585,12 +724,16 @@ def main(args):
                 "eos_token_id": tokenizer.eos_token_id,
             }

-            # Use model.chat with None for images (text-only mode)
-            text_response = full_model.chat(
-                tokenizer, None, question, generation_config, history=None
+            # Use helper function that supports both v1 and v2 models
+            text_response = _run_text_only_generation(
+                full_model, tokenizer, question, generation_config, args.pyt_ckpt_path
             )
-            generated_ids_before_ptq = text_response  # Store text response
-            print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            if text_response is not None:
+                generated_ids_before_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+            else:
+                raise Exception("Text-only generation returned None")

         except Exception as e:
             print(f"Text-only generation failed: {e}")
@@ -641,7 +784,7 @@ def main(args):
     elif is_nemotron_vl:
         print("Running text-only preview generation for quantized Nemotron VL model...")
         try:
-            # Try text-only generation using model.chat with None for images
+            # Try text-only generation using helper function that supports both v1 and v2
             if tokenizer is None:
                 raise ValueError("Tokenizer is required for Nemotron VL text generation")

@@ -652,12 +795,16 @@ def main(args):
                 "eos_token_id": tokenizer.eos_token_id,
             }

-            # Use model.chat with None for images (text-only mode)
-            text_response = full_model.chat(
-                tokenizer, None, question, generation_config, history=None
+            # Use helper function that supports both v1 and v2 models
+            text_response = _run_text_only_generation(
+                full_model, tokenizer, question, generation_config, args.pyt_ckpt_path
             )
-            generated_ids_after_ptq = text_response  # Store text response
-            print(f"✅ Text-only generation successful: {text_response[:100]}...")
+
+            if text_response is not None:
+                generated_ids_after_ptq = text_response  # Store text response
+                print(f"✅ Text-only generation successful: {text_response[:100]}...")
+            else:
+                generated_ids_after_ptq = None

         except Exception as e:
             print(f"Text-only generation failed: {e}")
