Commit b6d49b4

Create multi_modal_infer.py

1 parent 3e39ed0 commit b6d49b4

File tree

multi_modal_infer.py

1 file changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import os
import sys
import argparse
from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"

def load_model_and_processor(model_name: str):
    """
    Load the model and processor for the given 11B or 90B Vision-Instruct checkpoint.
    """
    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
    processor = MllamaProcessor.from_pretrained(model_name)
    return model, processor
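Not part of this commit, but on GPUs that cannot hold the 11B weights in bfloat16, a quantized load is a common alternative. A minimal sketch, assuming the optional bitsandbytes package is installed (quant_config is a name introduced here, not from the diff):

from transformers import BitsAndBytesConfig

# Assumption: bitsandbytes is installed; loads weights in 4-bit to cut memory roughly in half vs. bf16.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", quantization_config=quant_config)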

def process_image(image_path: str) -> PIL_Image.Image:
    """
    Open and convert an image from the specified path.
    """
    if not os.path.exists(image_path):
        print(f"The image file '{image_path}' does not exist.")
        sys.exit(1)
    with open(image_path, "rb") as f:
        return PIL_Image.open(f).convert("RGB")

def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    # Keyword arguments keep the image/text order unambiguous.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    # do_sample=True ensures temperature and top_p actually take effect.
    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512)
    # The decoded output begins with the prompt text, so slice it off.
    return processor.decode(output[0])[len(prompt):]
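The character-based slice above works because the decoded sequence begins with the prompt text verbatim. A slightly more robust alternative is to trim by token count before decoding; a sketch reusing the inputs and output names from the function above:

# Drop the prompt tokens, then decode only the newly generated ids.
generated_ids = output[0][inputs["input_ids"].shape[-1]:]
generated_text = processor.decode(generated_ids, skip_special_tokens=True)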

def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str):
    """
    Run the full pipeline: load the model, open the image, and generate text.
    """
    model, processor = load_model_and_processor(model_name)
    image = process_image(image_path)
    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
    print("Generated Text: " + result)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the Llama 3.2 multimodal model.")
    parser.add_argument("image_path", type=str, help="Path to the image file")
    parser.add_argument("prompt_text", type=str, help="Prompt text to describe the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")

    args = parser.parse_args()
    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name)
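For reference (not shown in the diff), a typical invocation of the finished script looks like this; the image path and prompt are placeholders:

python multi_modal_infer.py path/to/image.jpg "Describe this image" --temperature 0.7 --top_p 0.9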
