
Commit ce299b3

add get_custom_data_collator feature
1 parent 12da109 commit ce299b3

7 files changed: +213 −102 lines changed


recipes/quickstart/finetuning/datasets/vqa_dataset.py

Lines changed: 50 additions & 66 deletions

@@ -6,6 +6,7 @@
 from datasets import load_dataset
 import itertools
 import torch
+
 # check system prompt token seq or user prompt token seq is in the current token list
 def check_header(targets,seq):
     for i in range(len(seq)-3):

@@ -17,78 +18,61 @@ def replace_target(target,seq):
         if seq[i:i+3] == target:
             seq[i],seq[i+1],seq[i+2] = -100,-100,-100
     return seq
-def tokenize_dialog(dialog, images, processor):
+def tokenize_dialogs(dialogs, images, processor):
     # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
+    text_prompt = processor.apply_chat_template(dialogs)
     #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-    # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
-
-    batch["labels"] = torch.tensor(labels)
-    #pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
-    batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
-    # pixel_values .shape torch.Size([1, 4, 3, 560, 560])
-    print("pixel_values .shape",batch["pixel_values"].shape)
-    # exit()
-    # combined_tokens = {
-    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values = batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-    # print("combined_tokens",combined_tokens[image_sizes])
-
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    batch["labels"] = copy.copy(batch["input_ids"])
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        batch["labels"][i] = torch.tensor(labels)
     return batch
-def image_tokenize(sample, processor):
-    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
-    images,sample_text = sample["images"],sample["messages"]
-    dialog = []
-    for line in sample_text:
-        content = []
-        messages = line["content"]
-        role = line["role"]
-        for message in messages:
-            if message["type"] == "image":
-                content.append({"type": "image"})
-            elif message["type"] == "text":
-                content.append({"type": "text", "text": message["text"].strip()})
-        dialog.append({"role": role,"content":content})
-    return tokenize_dialog(dialog,images, processor)
-

 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
     dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
     dataset = dataset_dict[split]
     dataset = dataset.select(range(100))
-    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
-    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
-    return tokenized_datasets
+    return dataset
+
+class VQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image,sample_text = sample["images"],sample["messages"]
+            dialog = []
+            for line in sample_text:
+                content = []
+                messages = line["content"]
+                role = line["role"]
+                for message in messages:
+                    if message["type"] == "image":
+                        content.append({"type": "image"})
+                    elif message["type"] == "text":
+                        content.append({"type": "text", "text": message["text"].strip()})
+                dialog.append({"role": role,"content":content})
+            dialogs.append(dialog)
+            images.append(image)
+        return tokenize_dialogs(dialogs,images, self.processor)
+def get_data_collator(processor):
+    return VQADataCollator(processor)
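
After this change vqa_dataset.py returns the raw split and defers tokenization, prompt masking, and padding to batch time inside VQADataCollator. A minimal usage sketch of the two new entry points; the checkpoint name and the standalone DataLoader wiring are illustrative assumptions, not part of this commit:

from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Hypothetical vision checkpoint; any processor using the Llama 3 chat template behaves the same way.
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

dataset = get_custom_dataset(None, processor, "train")   # raw split, not tokenized yet
collator = get_data_collator(processor)                  # VQADataCollator instance

# Tokenization and the -100 masking of system/user prompt tokens now run once per batch.
loader = DataLoader(dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))  # dict with input_ids, labels, pixel_values, ...
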
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+# check system prompt token seq or user prompt token seq is in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialog(dialog, images, processor):
+    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
+    text_prompt = processor.apply_chat_template(dialog)
+    #print("text_prompt",text_prompt)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    labels = copy.copy(batch["input_ids"].tolist()[0])
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx+1
+    # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+    # print("pixel_values .shape",batch["pixel_values"].shape)
+    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
+
+    batch["labels"] = torch.tensor(labels)
+    #pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
+    batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
+    # pixel_values .shape torch.Size([1, 4, 3, 560, 560])
+    print("pixel_values .shape",batch["pixel_values"].shape)
+    # exit()
+    # combined_tokens = {
+    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    #     "input_ids": dialog_tokens,
+    #     "labels": labels,
+    #     "attention_mask": [1]*len(dialog_tokens),
+    #     "pixel_values": batch["pixel_values"],
+    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
+    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
+    #     "cross_attention_mask": batch["cross_attention_mask"]
+    # }
+    # input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values = batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+    # print("combined_tokens",combined_tokens[image_sizes])
+
+    return batch
+def image_tokenize(sample, processor):
+    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    images,sample_text = sample["images"],sample["messages"]
+    dialog = []
+    for line in sample_text:
+        content = []
+        messages = line["content"]
+        role = line["role"]
+        for message in messages:
+            if message["type"] == "image":
+                content.append({"type": "image"})
+            elif message["type"] == "text":
+                content.append({"type": "text", "text": message["text"].strip()})
+        dialog.append({"role": role,"content":content})
+    return tokenize_dialog(dialog,images, processor)
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset will return DatasetDict that contains all the data in the train set
+    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
+    dataset = dataset_dict[split]
+    dataset = dataset.select(range(100))
+    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
+    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
+    return tokenized_datasets

src/llama_recipes/datasets/__init__.py

Lines changed: 5 additions & 3 deletions

@@ -5,14 +5,16 @@

 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset,get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}
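
DATALOADER_COLLATE_FUNC mirrors DATASET_PREPROC: it maps a dataset name to a factory that builds its collator. The diff of src/llama_recipes/utils/dataset_utils.py (another of the 7 changed files) is not shown above; a plausible wrapper over this mapping could look like the sketch below, which is an assumption for illustration rather than the actual contents of that file:

from llama_recipes.datasets import DATALOADER_COLLATE_FUNC

def get_custom_data_collator(dataset_processer, dataset_config):
    # Only datasets registered in the mapping can supply a collator; everything else
    # keeps the default collate_fn chosen later by get_dataloader_kwargs.
    name = getattr(dataset_config, "dataset", None)
    if name not in DATALOADER_COLLATE_FUNC:
        return None
    return DATALOADER_COLLATE_FUNC[name](dataset_processer, dataset_config)
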

src/llama_recipes/datasets/custom_dataset.py

Lines changed: 20 additions & 0 deletions

@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e

+def get_data_collator(dataset_processer,dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None
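
The lookup mirrors get_custom_dataset: dataset_config.file may be either "path/to/module.py" or "path/to/module.py:function_name", and a module without get_data_collator falls back to the default collator by returning None. A quick call sketch, where the config stub and the processor line are illustrative assumptions:

from dataclasses import dataclass
from transformers import AutoProcessor
from llama_recipes.datasets.custom_dataset import get_data_collator

@dataclass
class DatasetConfigStub:
    # Without a ":<name>" suffix the loader looks for a function called "get_data_collator".
    file: str = "recipes/quickstart/finetuning/datasets/vqa_dataset.py:get_data_collator"

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")  # assumed checkpoint
collator = get_data_collator(processor, DatasetConfigStub())  # VQADataCollator, or None if the lookup fails
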

src/llama_recipes/finetuning.py

Lines changed: 7 additions & 3 deletions

@@ -45,7 +45,7 @@
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator

 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (

@@ -252,8 +252,12 @@ def main(**kwargs):
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,

@@ -268,7 +272,7 @@
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")

         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
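
The override order matters: get_dataloader_kwargs installs its default collate_fn first, and the custom collator replaces it only when get_custom_data_collator returns something non-None. Condensed into a standalone helper; the helper name, the pin_memory flag, and the import path of get_dataloader_kwargs are assumptions, not part of the commit:

import torch
from llama_recipes.utils.config_utils import get_dataloader_kwargs
from llama_recipes.utils.dataset_utils import get_custom_data_collator

def build_train_loader(train_config, dataset_train, dataset_processer, dataset_config):
    # Mirrors the new flow in main(): default kwargs first, optional collator override second.
    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        train_dl_kwargs["collate_fn"] = custom_data_collator
    return torch.utils.data.DataLoader(dataset_train, pin_memory=True, **train_dl_kwargs)
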

src/llama_recipes/utils/config_utils.py

Lines changed: 26 additions & 29 deletions

@@ -14,11 +14,11 @@
 )
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
+from functools import partial

 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC

 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):

@@ -76,39 +76,36 @@ def generate_dataset_config(train_config, kwargs):
     return dataset_config


-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode,collate_fn=None):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            if not collate_fn:
-                kwargs["collate_fn"] = collate_fn
-            else:
-                kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
-                    dataset,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                    drop_last=True,
-                )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
-        else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-        return kwargs
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                dataset,
+                batch_size=batch_size,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+        else:
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+                drop_last=True,
+            )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs


 def check_fsdp_config(fsdp_config):
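
Besides the re-indentation, get_dataloader_kwargs loses the collate_fn parameter (whose inverted "if not collate_fn" branch used to assign None instead of the supplied collator) and tokenizer becomes dataset_processer, which is handed straight to DataCollatorForSeq2Seq. Callers now override the collator after the call instead of passing it in; a minimal sketch assuming the padding strategy without FSDP, reusing the dataset, processor, and VQADataCollator from the earlier sketches:

from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.utils.config_utils import get_dataloader_kwargs

cfg = TRAIN_CONFIG()
cfg.batching_strategy = "padding"
cfg.enable_fsdp = False

# Returns {"batch_sampler": LengthBasedBatchSampler(...), "collate_fn": DataCollatorForSeq2Seq(processor)}
train_dl_kwargs = get_dataloader_kwargs(cfg, dataset, processor, "train")

# A custom collator is no longer an argument; it simply overwrites the returned key, as finetuning.py now does.
train_dl_kwargs["collate_fn"] = VQADataCollator(processor)
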
