
Commit bd22f40

changed to aid2 dataset
1 parent 1a76080 commit bd22f40

File tree: 4 files changed, +44 -97 lines


recipes/quickstart/finetuning/datasets/vqa_dataset.py

Lines changed: 21 additions & 75 deletions
```diff
@@ -51,57 +51,12 @@ def tokenize_dialogs(dialogs, images, processor):
     tokenizer_length = len(processor.tokenizer)
     return batch

-def tokenize_dialog(dialog, images, processor):
-    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
-    #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-    # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")

-    batch["labels"] = torch.tensor(labels)
-    # exit()
-    # combined_tokens = {
-    # #     "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    # #     "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values = batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-    # print("combined_tokens",combined_tokens[image_sizes])
-
-    return batch
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
-    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
-    dataset = dataset_dict[split]
-    dataset = dataset.select(range(500))
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")
+    dataset = dataset_dict['train']
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
     return dataset

 class VQADataCollator:
@@ -111,35 +66,26 @@ def __init__(self, processor):
     def __call__(self, samples):
         dialogs,images = [],[]
         for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
+            image_list,sample_list = sample["images"],sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only support one image per sample")
+            image = image_list[0].convert("RGB") # only use the first image
             dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only append image to the first sentence
+                    dialog += [
+                        {"role":"user","content":[{"type": "image"},{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                    ]
+
+                else:
+                    dialog += [
+                        {"role":"user","content":[{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                    ]
             dialogs.append(dialog)
-            images.append(image)
+            images.append([image])
         return tokenize_dialogs(dialogs,images, self.processor)
-    def __callworking__(self, samples):
-        for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
-            dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
-        return tokenize_dialog(dialog,image, self.processor)
 def get_data_collator(processor):
     return VQADataCollator(processor)
```
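
To make the new collator logic easier to follow, here is a small standalone sketch (not part of the commit) of how one `the_cauldron`-style sample, carrying the `images` and `texts` fields that the collator reads, is reshaped into the dialog format fed to the processor's chat template. The helper name `build_dialog` and the toy sample values are illustrative only.

```python
# Illustrative sketch only: mirrors the dialog-building loop in VQADataCollator.__call__.
def build_dialog(sample):
    image_list, turns = sample["images"], sample["texts"]
    if len(image_list) > 1:
        raise ValueError("Only support one image per sample")
    dialog = []
    for turn in turns:
        user_content = [{"type": "text", "text": turn["user"].strip()}]
        if not dialog:
            # the image placeholder is only attached to the first user turn
            user_content.insert(0, {"type": "image"})
        dialog.append({"role": "user", "content": user_content})
        dialog.append({"role": "assistant",
                       "content": [{"type": "text", "text": turn["assistant"].strip()}]})
    return dialog

# Toy sample shaped like an ai2d record (real records store PIL images under "images").
toy_sample = {
    "images": ["<PIL.Image placeholder>"],
    "texts": [{"user": "Which label points to the stem?", "assistant": "B"}],
}
print(build_dialog(toy_sample))
```

Attaching the image token only to the first user turn keeps multi-turn samples valid while each sample still contributes exactly one image to the batch.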
Lines changed: 15 additions & 9 deletions
````diff
@@ -1,5 +1,5 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
-Here we discuss fine-tuning Meta Llama 3.2 11B and 90B models.
+This recipe steps you through how to finetune a Llama 3.2 vision model on the VQA task using the [the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) dataset.

 ### Concepts
 Model Architecture
@@ -12,18 +12,24 @@ We need have a new processor class added, that will handle the image processing


 ### Fine-tuning steps
-1. Download the dataset:
-an example of the dataset looks like this:
-2. Processor example looks like this

-3. Load the dataset

-Full-finetune
+For **full finetuning with FSDP**, we can run the following code:
 ```bash
-torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name nltpt/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_11bmodel --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding --use-wandb
+torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding
 ```

-LoRA:
+For **LoRA finetuning with FSDP**, we can run the following code:
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 1 --batch_size_training 1 --model_name llava-hf/llama3-llava-next-8b-hf --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --use-wandb --run_validation True
+torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
 ```
+**Note**: `--batching_strategy padding` is needed as the vision model will not work with the `packing` method.
+
+For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
+
+### How to use a custom dataset to fine-tune the vision model
+
+1. Create a new dataset python file under the `recipes/quickstart/finetuning/dataset` folder.
+2. In this python file, define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
+3. In this python file, define a `get_data_collator(processor)` function that returns a custom data collator that can be used by the PyTorch DataLoader.
+4. This custom data collator class must have a `__call__(self, samples)` function that converts the image and text samples into the actual inputs that the vision model expects.
````
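
To illustrate the four steps listed in the doc diff above, here is a minimal skeleton of such a custom dataset file. It is a hedged sketch, not code from this commit: the dataset id `my_org/my_vqa_dataset` and the `image`/`question`/`answer` column names are placeholders, and a real collator (like the one in `vqa_dataset.py`) would additionally mask prompt tokens in the labels.

```python
# Hypothetical custom dataset file following the four steps above.
# Dataset id and column names are placeholders; adapt them to your data.
from datasets import load_dataset


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Step 2: load the data and return the requested split.
    dataset = load_dataset("my_org/my_vqa_dataset")["train"]
    return dataset.train_test_split(test_size=1 - split_ratio, seed=42)[split]


class MyVQACollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        # Step 4: convert raw image/text samples into model inputs.
        images = [[s["image"].convert("RGB")] for s in samples]
        dialogs = [
            [
                {"role": "user",
                 "content": [{"type": "image"},
                             {"type": "text", "text": s["question"]}]},
                {"role": "assistant",
                 "content": [{"type": "text", "text": s["answer"]}]},
            ]
            for s in samples
        ]
        # Render all dialogs with the processor's chat template (batched, as in vqa_dataset.py).
        texts = self.processor.apply_chat_template(dialogs)
        batch = self.processor(images=images, text=texts, padding=True, return_tensors="pt")
        # Simplified labels: train on every non-padded token. The recipe's real
        # collator also masks the prompt tokens with -100.
        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch


def get_data_collator(processor):
    # Step 3: return the collator used by the PyTorch DataLoader.
    return MyVQACollator(processor)
```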

src/llama_recipes/finetuning.py

Lines changed: 6 additions & 8 deletions
```diff
@@ -22,6 +22,7 @@
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
     LlamaForCausalLM,
@@ -125,7 +126,8 @@ def main(**kwargs):

     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    if "11B" in train_config.model_name or "90B" in train_config.model_name:
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
         is_vision = True
         model = MllamaForConditionalGeneration.from_pretrained(
             train_config.model_name,
@@ -136,7 +138,7 @@ def main(**kwargs):
         )
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
-    else:
+    elif config.model_type == "llama":
         is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
@@ -146,7 +148,8 @@ def main(**kwargs):
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-    print(model)
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -190,7 +193,6 @@ def main(**kwargs):

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
-        print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -218,8 +220,6 @@ def main(**kwargs):
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")
-    print("-------------------")
-    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -306,8 +306,6 @@ def main(**kwargs):
            weight_decay=train_config.weight_decay,
        )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
-
     results = train(
         model,
         train_dataloader,
```
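
The main change above swaps the brittle name check ("11B"/"90B" in the model name) for a dispatch on the Hugging Face config's `model_type`. A minimal sketch of that pattern, with example model ids only (gated repos need Hugging Face authentication):

```python
# Sketch of the config-based dispatch introduced in finetuning.py.
# The model ids are examples; any local path or hub id works the same way.
from transformers import AutoConfig

for name in ["meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-8B-Instruct"]:
    config = AutoConfig.from_pretrained(name)
    if config.model_type == "mllama":
        print(f"{name}: vision model -> MllamaForConditionalGeneration + AutoProcessor")
    elif config.model_type == "llama":
        print(f"{name}: text-only model -> LlamaForCausalLM + AutoTokenizer")
    else:
        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
```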

src/llama_recipes/utils/train_utils.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -132,7 +132,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     with profile(train_config,local_rank) as profile_context:
         for step, batch in enumerate(train_dataloader):
             total_train_steps += 1
-            #print("batch: ", batch)
             # stop when the maximum number of training steps is reached
             if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                 max_steps_reached = True
@@ -151,10 +150,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 elif torch.cuda.is_available():
                     batch[key] = batch[key].to('cuda:0')
             with autocast():
-                assert(next(model.parameters()).device == batch['input_ids'].device)
                 loss = model(**batch).loss
             loss = loss / gradient_accumulation_steps
-            #print("loss",loss)
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
                 train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -175,7 +172,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 pbar.update(1)
             else:
                 # regular backpropagation when fp16 is not used
-                #print("loss123",loss)
                 loss.backward()
                 if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                     if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
@@ -364,7 +360,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
-                outputs = model(**batch,use_cache=False)
+                #outputs = model(**batch,use_cache=False)
+                outputs = model(**batch)
             loss = outputs.loss
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
```
