4
4
from PIL import Image as PIL_Image
5
5
import torch
6
6
from transformers import MllamaForConditionalGeneration , MllamaProcessor
7
+ from accelerate import Accelerator
7
8
9
+ accelerator = Accelerator ()
10
+
11
+ device = accelerator .device
8
12
9
13
# Constants
10
14
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
def load_model_and_processor(model_name: str, hf_token: str):
    """
    Load the Mllama model and processor (11B or 90B vision-instruct checkpoint).

    Args:
        model_name: Hugging Face model id to download, e.g. the 11B or 90B
            Llama-3.2 vision-instruct checkpoint.
        hf_token: Hugging Face access token used to fetch the gated weights.

    Returns:
        A ``(model, processor)`` pair, both passed through
        ``accelerator.prepare`` so they run on the Accelerate-selected device.
    """
    # NOTE(review): relies on the module-level `accelerator` / `device`
    # created at import time via accelerate.Accelerator().
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(
        model_name,
        token=hf_token,
        use_safetensors=True,
    )
    model, processor = accelerator.prepare(model, processor)
    return model, processor
20
27
21
28
@@ -38,7 +45,7 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
38
45
{"role" : "user" , "content" : [{"type" : "image" }, {"type" : "text" , "text" : prompt_text }]}
39
46
]
40
47
prompt = processor .apply_chat_template (conversation , add_generation_prompt = True , tokenize = False )
41
- inputs = processor (image , prompt , return_tensors = "pt" ).to (model . device )
48
+ inputs = processor (image , prompt , return_tensors = "pt" ).to (device )
42
49
output = model .generate (** inputs , temperature = temperature , top_p = top_p , max_new_tokens = 512 )
43
50
return processor .decode (output [0 ])[len (prompt ):]
44
51
@@ -63,4 +70,4 @@ def main(image_path: str, prompt_text: str, temperature: float, top_p: float, mo
63
70
parser .add_argument ("--hf_token" , type = str , required = True , help = "Hugging Face token for authentication" )
64
71
65
72
args = parser .parse_args ()
66
- main (args .image_path , args .prompt_text , args .temperature , args .top_p , args .model_name , args .hf_token )
73
+ main (args .image_path , args .prompt_text , args .temperature , args .top_p , args .model_name , args .hf_token )
0 commit comments