
Commit 799e90e

Support converting fine-tuned llama 3.2 vision model to HF format and then local inference (meta-llama#737)
2 parents: 82d4049 + 8715e04

File tree: 4 files changed (+98, -40 lines)

recipes/quickstart/finetuning/finetune_vision_model.md

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ For **LoRA finetuning with FSDP**, we can run the following code:
 
 For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
 
+For more details about local inference with the fine-tuned checkpoint, please read [Inference with FSDP checkpoints section](https://github.com/meta-llama/llama-recipes/tree/main/recipes/quickstart/inference/local_inference#inference-with-fsdp-checkpoints) to learn how to convert the FSDP weights into a consolidated Hugging Face formatted model for local inference.
+
 ### How to use a custom dataset to fine-tune vision model
 
 In order to use a custom dataset, please follow the steps below:
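As context for the new note: once the FSDP shards have been consolidated into a Hugging Face formatted directory (the conversion steps are in the linked inference README), the checkpoint loads like any other hub model. A minimal sketch, assuming the converted weights were saved to a hypothetical local directory `./finetuned-llama-3.2-11b-vision-hf`:

```
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Hypothetical output directory produced by the FSDP-to-HF consolidation step.
converted_path = "./finetuned-llama-3.2-11b-vision-hf"

# The consolidated checkpoint loads exactly like a hub-hosted model.
model = MllamaForConditionalGeneration.from_pretrained(
    converted_path, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = MllamaProcessor.from_pretrained(converted_path)
```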

recipes/quickstart/inference/local_inference/README.md

Lines changed: 6 additions & 3 deletions

@@ -1,11 +1,14 @@
 # Local Inference
 
+## Hugging face setup
+**Important Note**: Before running the inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and copy and paste your Hugging Face access token to complete the login to make sure the scripts can download Hugging Face models if needed.
+
 ## Multimodal Inference
-For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library
+For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library.
 
-The way to run this would be
+The way to run this would be:
 ```
-python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
+python multi_modal_infer.py --image_path PATH_TO_IMAGE --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
 ```
 
 ## Text-only Inference
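For environments where the interactive `huggingface-cli login` flow is inconvenient (CI jobs, notebooks), the same authentication can be done programmatically through the `huggingface_hub` package. A small sketch, assuming the access token is stored in an `HF_TOKEN` environment variable rather than typed in interactively:

```
import os

from huggingface_hub import login

# Read the access token from the environment rather than hard-coding it.
login(token=os.environ["HF_TOKEN"])
```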
recipes/quickstart/inference/local_inference/multi_modal_infer.py

Lines changed: 59 additions & 24 deletions
@@ -1,10 +1,11 @@
+import argparse
 import os
 import sys
-import argparse
-from PIL import Image as PIL_Image
+
 import torch
+from accelerate import Accelerator
+from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from accelerate import Accelerator
 
 accelerator = Accelerator()
 
@@ -14,15 +15,19 @@
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
-def load_model_and_processor(model_name: str, hf_token: str):
+def load_model_and_processor(model_name: str):
     """
     Load the model and processor based on the 11B or 90B model.
     """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
-                                                           token=hf_token)
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)
+    model = MllamaForConditionalGeneration.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
+        device_map=device,
+    )
+    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
 
-    model, processor=accelerator.prepare(model, processor)
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
 
@@ -37,37 +42,67 @@ def process_image(image_path: str) -> PIL_Image.Image:
         return PIL_Image.open(f).convert("RGB")
 
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+def generate_text_from_image(
+    model, processor, image, prompt_text: str, temperature: float, top_p: float
+):
     """
     Generate text from an image using the model and processor.
     """
     conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        {
+            "role": "user",
+            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
+        }
     ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    prompt = processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=False
+    )
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-    return processor.decode(output[0])[len(prompt):]
+    output = model.generate(
+        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
+    )
+    return processor.decode(output[0])[len(prompt) :]
 
 
-def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
+def main(
+    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
+):
     """
-    Call all the functions.
+    Call all the functions.
     """
-    model, processor = load_model_and_processor(model_name, hf_token)
+    model, processor = load_model_and_processor(model_name)
     image = process_image(image_path)
-    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
+    result = generate_text_from_image(
+        model, processor, image, prompt_text, temperature, top_p
+    )
     print("Generated Text: " + result)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the 3.2 MM Llama model.")
+    parser = argparse.ArgumentParser(
+        description="Generate text from an image and prompt using the 3.2 MM Llama model."
+    )
     parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
-    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
+    parser.add_argument(
+        "--prompt_text", type=str, help="Prompt text to describe the image"
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Temperature for generation (default: 0.7)",
+    )
+    parser.add_argument(
+        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default=DEFAULT_MODEL,
+        help=f"Model name (default: '{DEFAULT_MODEL}')",
+    )
 
     args = parser.parse_args()
-    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
+    main(
+        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
+    )
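With this change the script no longer takes an `--hf_token` flag and instead relies on the cached `huggingface-cli login` credentials. A minimal sketch of driving the refactored helpers programmatically (assuming the script is importable from the current directory and that `./image.jpg` is a placeholder path):

```
from multi_modal_infer import (
    generate_text_from_image,
    load_model_and_processor,
    process_image,
)

# Placeholder inputs; substitute a real image path and prompt.
model, processor = load_model_and_processor("meta-llama/Llama-3.2-11B-Vision-Instruct")
image = process_image("./image.jpg")
result = generate_text_from_image(
    model, processor, image, "Describe this image", 0.5, 0.8
)
print(result)
```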

src/llama_recipes/inference/model_utils.py

Lines changed: 31 additions & 13 deletions
@@ -1,17 +1,29 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+from warnings import warn
+
+from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from llama_recipes.utils.config_utils import update_config
-from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
-from warnings import warn
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    LlamaConfig,
+    LlamaForCausalLM,
+    MllamaConfig,
+    MllamaForConditionalGeneration,
+)
+
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization, use_fast_kernels, **kwargs):
     if type(quantization) == type(True):
-        warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
-        quantization = "8bit"
+        warn(
+            "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
+            FutureWarning,
+        )
+        quantization = "8bit"
 
     bnb_config = None
     if quantization:
@@ -23,10 +35,10 @@ def load_model(model_name, quantization, use_fast_kernels, **kwargs):
 
     kwargs = {}
     if bnb_config:
-        kwargs["quantization_config"]=bnb_config
-    kwargs["device_map"]="auto"
-    kwargs["low_cpu_mem_usage"]=True
-    kwargs["attn_implementation"]="sdpa" if use_fast_kernels else None
+        kwargs["quantization_config"] = bnb_config
+    kwargs["device_map"] = "auto"
+    kwargs["low_cpu_mem_usage"] = True
+    kwargs["attn_implementation"] = "sdpa" if use_fast_kernels else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
@@ -40,10 +52,16 @@ def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
     return peft_model
 
+
 # Loading the model from config to load FSDP checkpoints into that
 def load_llama_from_config(config_path):
-    model_config = LlamaConfig.from_pretrained(config_path)
-    model = LlamaForCausalLM(config=model_config)
+    config = AutoConfig.from_pretrained(config_path)
+    if config.model_type == "mllama":
+        model = MllamaForConditionalGeneration(config=config)
+    elif config.model_type == "llama":
+        model = LlamaForCausalLM(config=config)
+    else:
+        raise ValueError(
+            f"Unsupported model type: {config.model_type}, Please use llama or mllama model."
+        )
     return model
-
-
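The updated `load_llama_from_config` now dispatches on `config.model_type`, so both text-only Llama and Llama 3.2 vision (mllama) FSDP checkpoints can be rebuilt from their saved config before the consolidated weights are loaded in. A minimal sketch of calling it, where the config directory is a hypothetical placeholder for the folder holding the checkpoint's `config.json`:

```
from llama_recipes.inference.model_utils import load_llama_from_config

# Hypothetical directory containing the config saved alongside the FSDP checkpoint.
model_def = load_llama_from_config("./fsdp-checkpoint-config")

# Prints LlamaForCausalLM or MllamaForConditionalGeneration depending on model_type.
print(type(model_def).__name__)
```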
