Commit e1bbffc

Merge pull request meta-llama#13 from meta-llama/lmm_finetune
add vision model finetune recipe
2 parents 3e39ed0 + 57afa0b commit e1bbffc

File tree: 11 files changed (+256, -62 lines)


.github/scripts/spellcheck_conf/wordlist.txt

Lines changed: 3 additions & 0 deletions
@@ -1451,3 +1451,6 @@ openhathi
 sarvam
 subtask
 acc
+OCRVQA
+OCRVQADataCollator
+ocrvqa
recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.


import copy
from datasets import load_dataset
import itertools
import torch

# check whether the system prompt or user prompt token sequence is in the current token list
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] in targets:
            return True
    return False

def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] == target:
            seq[i], seq[i+1], seq[i+2] = -100, -100, -100
    return seq

def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        eot_indices = [i for i, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007]
        # user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007]
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for n, idx in enumerate(eot_indices):
            current_seq = labels[last_idx:idx+1]
            if check_header(prompt_header_seqs, current_seq):
                # found a prompt header, indicating that this seq should be masked
                labels[last_idx:idx+1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which is tokenized to [128006, 78191, 128007]
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask the padding tokens and the image token 128256
        for i in range(len(labels)):
            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256:  # 128256 is the image token index
                labels[i] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict that contains all the data in the train set
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset


class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # during training, always pad on the right

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only support one image per sample")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # only append the image to the first sentence
                    dialog += [
                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
                else:
                    dialog += [
                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
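
If you adapt this masking logic to a different checkpoint, it is worth confirming that the hard-coded token IDs above match your tokenizer. Below is a minimal sanity-check sketch (not part of the commit; it assumes you can download the `meta-llama/Llama-3.2-11B-Vision-Instruct` processor):

```python
# Sanity-check sketch: verify the hard-coded header and end-of-turn token IDs.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
tok = processor.tokenizer

# Per the comments in ocrvqa_dataset.py, these should print
# [128006, 9125, 128007], [128006, 882, 128007], [128006, 78191, 128007] and 128009.
print(tok.encode("<|start_header_id|>system<|end_header_id|>", add_special_tokens=False))
print(tok.encode("<|start_header_id|>user<|end_header_id|>", add_special_tokens=False))
print(tok.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False))
print(tok.convert_tokens_to_ids("<|eot_id|>"))
```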
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
## Fine-Tuning Meta Llama Multimodal Models Recipe

This recipe steps you through fine-tuning a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.

**Disclaimer**: Since our vision models already have very good OCR ability, we use the OCRVQA dataset here only to demonstrate the steps required to fine-tune our vision models with llama-recipes.

### Fine-tuning steps

We provide an example script, [ocrvqa_dataset.py](./datasets/ocrvqa_dataset.py), which loads the OCRVQA dataset via a `get_custom_dataset` function and supplies an `OCRVQADataCollator` class to preprocess the image-and-text samples.
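
For orientation, the sketch below shows how these two pieces fit together outside of `finetuning.py` (normally the training script wires them up for you via `--custom_dataset.file`). It assumes the `meta-llama/Llama-3.2-11B-Vision-Instruct` processor is available and that `ocrvqa_dataset.py` is importable as a module:

```python
# Standalone usage sketch of the dataset/collator contract; not required for training,
# since finetuning.py builds the DataLoader for you.
from transformers import AutoProcessor
from torch.utils.data import DataLoader

from ocrvqa_dataset import get_custom_dataset, get_data_collator  # assumes the file is on PYTHONPATH

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

# dataset_config is unused by this particular get_custom_dataset, so None is fine here
train_set = get_custom_dataset(None, processor, split="train")
collator = get_data_collator(processor)

loader = DataLoader(train_set, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["labels"].shape)  # labels mask prompts/padding with -100
```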

For **full finetuning with FSDP**, run the following command:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding
```

For **LoRA finetuning with FSDP**, run the following command:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
```

**Note**: `--batching_strategy padding` is required, as the vision model does not work with the `packing` batching strategy.

For more details about the finetuning configurations, please read the [finetuning readme](./README.md).

### How to use a custom dataset to fine-tune the vision model

To use a custom dataset, follow the steps below:

1. Create a new dataset Python file under the `recipes/quickstart/finetuning/datasets` folder.
2. In this file, define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
3. In the same file, define a `get_data_collator(processor)` function that returns a custom data collator usable by the PyTorch DataLoader.
4. The custom data collator class must implement a `__call__(self, samples)` method that converts the image and text samples into the actual inputs the vision model expects.
5. Run the `torchrun` command from the section above, changing `--custom_dataset.file` to point to your new dataset file and adjusting the learning rate accordingly. A minimal skeleton of such a file is sketched below.
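
For reference, here is a minimal, hypothetical skeleton of such a file. The dataset name `my_org/my_vision_dataset` and the `image`/`question`/`answer` column names are placeholders; a real collator should also mask prompt and padding tokens with -100, as `ocrvqa_dataset.py` does:

```python
# my_vision_dataset.py -- hypothetical skeleton of a custom vision dataset module
from datasets import load_dataset


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Load your own dataset here; "my_org/my_vision_dataset" is a placeholder name.
    dataset = load_dataset("my_org/my_vision_dataset")["train"]
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset


class MyVisionDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        # Build one single-image, single-turn dialog per sample; adjust the
        # "image"/"question"/"answer" column names to match your dataset.
        dialogs, images = [], []
        for sample in samples:
            dialogs.append([
                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample["question"]}]},
                {"role": "assistant", "content": [{"type": "text", "text": sample["answer"]}]},
            ])
            images.append([sample["image"].convert("RGB")])
        text = self.processor.apply_chat_template(dialogs)
        batch = self.processor(images=images, text=text, padding=True, return_tensors="pt")
        # For a real run, also mask prompt/padding tokens with -100 as in ocrvqa_dataset.py;
        # here the labels are simply a copy of the inputs to keep the sketch short.
        batch["labels"] = batch["input_ids"].clone()
        return batch


def get_data_collator(processor):
    return MyVisionDataCollator(processor)
```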

src/llama_recipes/datasets/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -5,14 +5,16 @@
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset, get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

src/llama_recipes/datasets/custom_dataset.py

Lines changed: 20 additions & 0 deletions
@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
 
+def get_data_collator(dataset_processer, dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None

src/llama_recipes/finetuning.py

Lines changed: 61 additions & 22 deletions
@@ -14,16 +14,18 @@
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    LlamaConfig,
+    AutoProcessor,
+    MllamaForConditionalGeneration,
+    AutoModel,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -39,7 +41,7 @@
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset, get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -118,19 +120,35 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        quantization_config=bnb_config,
-        use_cache=use_cache,
-        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-        device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
-        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-    )
-
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side = "right"
+    elif config.model_type == "llama":
+        is_vision = False
+        model = AutoModel.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -169,8 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
            device_id = torch.xpu.current_device()
@@ -198,52 +220,70 @@
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
 
-    # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="train",
     )
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="test",
     )
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -266,7 +306,6 @@ def main(**kwargs):
         weight_decay=train_config.weight_decay,
     )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,

src/llama_recipes/policies/wrapping.py

Lines changed: 3 additions & 3 deletions
@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -25,9 +27,7 @@ def get_llama_wrapper():
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaVisionEncoderLayer, MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy
