Skip to content

Commit ad5ce80

Browse files
authored (author name not captured in extraction)
[Fixed] RuntimeError: probability tensor contains either inf, nan or element < 0 (meta-llama#704)
2 parents d9aab46 + 625860d commit ad5ce80

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

recipes/quickstart/inference/local_inference/multi_modal_infer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from PIL import Image as PIL_Image
55
import torch
66
from transformers import MllamaForConditionalGeneration, MllamaProcessor
7+
from accelerate import Accelerator
78

9+
accelerator = Accelerator()
10+
11+
device = accelerator.device
812

913
# Constants
1014
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -14,8 +18,11 @@ def load_model_and_processor(model_name: str, hf_token: str):
1418
"""
1519
Load the model and processor based on the 11B or 90B model.
1620
"""
17-
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
18-
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
21+
model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
22+
token=hf_token)
23+
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)
24+
25+
model, processor=accelerator.prepare(model, processor)
1926
return model, processor
2027

2128

@@ -38,7 +45,7 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
3845
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
3946
]
4047
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
41-
inputs = processor(image, prompt, return_tensors="pt").to(model.device)
48+
inputs = processor(image, prompt, return_tensors="pt").to(device)
4249
output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
4350
return processor.decode(output[0])[len(prompt):]
4451

@@ -63,4 +70,4 @@ def main(image_path: str, prompt_text: str, temperature: float, top_p: float, mo
6370
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
6471

6572
args = parser.parse_args()
66-
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
73+
main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)

0 commit comments

Comments (0)