Skip to content

Commit e45b4c6

Browse files
author
Sanyam Bhutani
committed
final fixes
1 parent c587a7f commit e45b4c6

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

recipes/quickstart/inference/local_inference/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer
44

55
The way to run this would be
66
```
7-
python multi_modal_infer.py --image_path "../../../responsible_ai/resources/dog.jpg" --input_prompt "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
7+
python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
88
```
99

1010
For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.

recipes/quickstart/inference/local_inference/multi_modal_infer.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
import torch
66
from transformers import MllamaForConditionalGeneration, MllamaProcessor
77

8+
89
# Constants
910
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
1011

11-
def load_model_and_processor(model_name: str):
12+
13+
def load_model_and_processor(model_name: str, hf_token: str):
1214
"""
1315
Load the model and processor based on the 11B or 90B model.
1416
"""
15-
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
16-
processor = MllamaProcessor.from_pretrained(model_name)
17+
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
18+
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
1719
return model, processor
1820

21+
1922
def process_image(image_path: str) -> PIL_Image.Image:
2023
"""
2124
Open and convert an image from the specified path.
@@ -26,6 +29,7 @@ def process_image(image_path: str) -> PIL_Image.Image:
2629
with open(image_path, "rb") as f:
2730
return PIL_Image.open(f).convert("RGB")
2831

32+
2933
def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
3034
"""
3135
Generate text from an image using the model and processor.
@@ -38,22 +42,25 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
3842
output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
3943
return processor.decode(output[0])[len(prompt):]
4044

41-
def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str):
45+
46+
def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
4247
"""
4348
Call all the functions.
4449
"""
45-
model, processor = load_model_and_processor(model_name)
50+
model, processor = load_model_and_processor(model_name, hf_token)
4651
image = process_image(image_path)
4752
result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
4853
print("Generated Text: " + result)
4954

55+
5056
if __name__ == "__main__":
5157
parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the 3.2 MM Llama model.")
52-
parser.add_argument("image_path", type=str, help="Path to the image file")
53-
parser.add_argument("prompt_text", type=str, help="Prompt text to describe the image")
58+
parser.add_argument("--image_path", type=str, help="Path to the image file")
59+
parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
5460
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
5561
parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
5662
parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
63+
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
5764

5865
args = parser.parse_args()
59-
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name)
66+
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)

0 commit comments

Comments
 (0)