
Commit 79dbe05

batch fine-tuning lmm working

1 parent ce299b3 · commit 79dbe05
File tree

5 files changed: +70, -103 lines

recipes/quickstart/finetuning/datasets/vqa_dataset.py

Lines changed: 66 additions & 3 deletions
```diff
@@ -23,7 +23,7 @@ def tokenize_dialogs(dialogs, images, processor):
     text_prompt = processor.apply_chat_template(dialogs)
     #print("text_prompt",text_prompt)
     batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
-    batch["labels"] = copy.copy(batch["input_ids"])
+    label_list = []
     for i in range(len(batch["input_ids"])):
         dialog_tokens = batch["input_ids"][i].tolist()
         labels = copy.copy(dialog_tokens)
@@ -42,14 +42,62 @@ def tokenize_dialogs(dialogs, images, processor):
         # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
         assistant_header_seq = [128006, 78191, 128007]
         labels = replace_target(assistant_header_seq,labels)
-        batch["labels"][i] = torch.tensor(labels)
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
+    tokenizer_length = len(processor.tokenizer)
     return batch
 
+def tokenize_dialog(dialog, images, processor):
+    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
+    text_prompt = processor.apply_chat_template(dialog)
+    #print("text_prompt",text_prompt)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    labels = copy.copy(batch["input_ids"].tolist()[0])
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx+1
+    # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+    # print("pixel_values .shape",batch["pixel_values"].shape)
+    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
+
+    batch["labels"] = torch.tensor(labels)
+    # exit()
+    # combined_tokens = {
+    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    #     "input_ids": dialog_tokens,
+    #     "labels": labels,
+    #     "attention_mask": [1]*len(dialog_tokens),
+    #     "pixel_values": batch["pixel_values"],
+    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
+    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
+    #     "cross_attention_mask": batch["cross_attention_mask"]
+    # }
+    # input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values = batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+    # print("combined_tokens",combined_tokens[image_sizes])
+
+    return batch
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
     dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
     dataset = dataset_dict[split]
-    dataset = dataset.select(range(100))
+    dataset = dataset.select(range(500))
     return dataset
 
 class VQADataCollator:
@@ -74,5 +122,20 @@ def __call__(self, samples):
             dialogs.append(dialog)
             images.append(image)
         return tokenize_dialogs(dialogs,images, self.processor)
+    def __callworking__(self, samples):
+        for sample in samples:
+            image,sample_text = sample["images"],sample["messages"]
+            dialog = []
+            for line in sample_text:
+                content = []
+                messages = line["content"]
+                role = line["role"]
+                for message in messages:
+                    if message["type"] == "image":
+                        content.append({"type": "image"})
+                    elif message["type"] == "text":
+                        content.append({"type": "text", "text": message["text"].strip()})
+                dialog.append({"role": role,"content":content})
+        return tokenize_dialog(dialog,image, self.processor)
 def get_data_collator(processor):
     return VQADataCollator(processor)
```
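Note: both tokenizer functions above lean on `check_header` and `replace_target`, which are defined earlier in `vqa_dataset.py` and fall outside this diff. A minimal sketch of their assumed behavior (3-token subsequence detection and in-place masking; the committed implementations may differ in detail):

```python
# Sketch only: the real helpers live earlier in vqa_dataset.py, outside this diff.
def check_header(targets, seq):
    # True if any 3-token header sequence from `targets` occurs in `seq`.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False

def replace_target(target, seq):
    # Mask every occurrence of the 3-token `target` in `seq` with -100 in place.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i : i + 3] = [-100] * 3
    return seq
```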

recipes/quickstart/finetuning/datasets/vqa_dataset_old.py

Lines changed: 0 additions & 94 deletions. This file was deleted.

recipes/quickstart/finetuning/finetune_vision_model.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -20,7 +20,7 @@ an example of the dataset looks like this:
 
 Full-finetune
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 1 --batch_size_training 1 --model_name llava-hf/llama3-llava-next-8b-hf --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --use-wandb --run_validation True
+torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name nltpt/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_11bmodel --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding --use-wandb
 ```
 
 LoRA:
````

src/llama_recipes/finetuning.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -273,6 +273,8 @@ def main(**kwargs):
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
```
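With this change, validation batches flow through the same custom collator as training. A rough, self-contained sketch of how the pieces from this commit are assumed to fit together; the processor ID comes from the doc change above, while the batch size and the "test" split argument are illustrative:

```python
import torch
from transformers import AutoProcessor

# Illustrative wiring only: get_custom_dataset and get_data_collator
# come from vqa_dataset.py in this commit.
processor = AutoProcessor.from_pretrained("nltpt/Llama-3.2-11B-Vision-Instruct")
collator = get_data_collator(processor)  # returns VQADataCollator(processor)
dataset_val = get_custom_dataset(None, processor, "test")  # dataset_config unused here

eval_dataloader = torch.utils.data.DataLoader(
    dataset_val, batch_size=2, collate_fn=collator, pin_memory=True
)
batch = next(iter(eval_dataloader))
# input_ids and labels share a shape; prompt and header tokens in labels
# are masked to -100 so only assistant answers contribute to the loss.
print(batch["input_ids"].shape, batch["labels"].shape)
```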

src/llama_recipes/utils/train_utils.py

Lines changed: 1 addition & 5 deletions

```diff
@@ -146,16 +146,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         else:
                             batch[key] = batch[key].to(local_rank)
                     else:
-
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
-                        else:
+                        elif torch.cuda.is_available():
                             batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     assert(next(model.parameters()).device == batch['input_ids'].device)
-                    #print("batch: ", batch)
-                    pixel_values = batch['pixel_values']
-                    print("pixel_values.shape input",pixel_values.shape)
                     loss = model(**batch).loss
                     loss = loss / gradient_accumulation_steps
                     #print("loss",loss)
```
