Skip to content

Commit b566582

Browse files
committed
finetune not working with fsdp
1 parent 1d90dbe commit b566582

File tree

5 files changed

+114
-11
lines changed

5 files changed

+114
-11
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
3+
4+
5+
import copy
6+
from datasets import load_dataset
7+
import itertools
8+
# Check whether a system-prompt or user-prompt header token sequence occurs in the current token list.
9+
def check_header(targets, seq):
    """Return True if any header token sequence in ``targets`` occurs in ``seq``.

    Args:
        targets: List of token-id lists to look for (the callers in this file
            pass the 3-token system and user prompt headers).
        seq: List of token ids to scan.

    Returns:
        True when at least one target appears as a contiguous run anywhere in
        ``seq`` (including at its very end), otherwise False.
    """
    for target in targets:
        width = len(target)
        if width == 0:
            # An empty pattern carries no header information; never match it.
            continue
        # "+ 1" so the final window (the one ending on the last token) is
        # inspected too; the former ``len(seq) - 3`` bound skipped it.
        for start in range(len(seq) - width + 1):
            if seq[start:start + width] == target:
                return True
    return False
14+
def replace_target(target, seq):
    """Mask every occurrence of ``target`` inside ``seq`` with -100, in place.

    Args:
        target: List of token ids to search for (the caller in this file
            passes the 3-token assistant header).
        seq: List of label ids; it is mutated in place.

    Returns:
        The same ``seq`` object, with each matched run overwritten by -100
        (the value this file uses to mark masked label positions).
    """
    width = len(target)
    if width == 0:
        # Nothing to look for; leave the labels untouched.
        return seq
    # "+ 1" so a match ending on the last token is also replaced; the former
    # ``len(seq) - 3`` bound never reached that final window.
    for start in range(len(seq) - width + 1):
        if seq[start:start + width] == target:
            seq[start:start + width] = [-100] * width
    return seq
19+
def tokenize_dialog(dialog, images, processor):
    """Tokenize one image dialog and build labels that mask the prompt turns.

    The dialog is rendered with the processor's chat template, tokenized
    together with ``images``, and a ``labels`` list is derived from the input
    ids in which system/user turns and the assistant header are set to -100.

    Args:
        dialog: Chat messages in the format accepted by
            ``processor.apply_chat_template``.
        images: Image input forwarded to ``processor(images=...)``.
        processor: Multimodal processor providing ``apply_chat_template`` and
            ``__call__``. Its outputs are assumed to expose ``.tolist()`` and
            to be batched (only row 0 is used) -- TODO confirm against caller.

    Returns:
        dict with keys ``input_ids``, ``labels``, ``attention_mask``
        (all ones, same length as ``input_ids``), ``pixel_values`` and
        ``image_sizes``.
    """
    text_prompt = processor.apply_chat_template(dialog)
    batch = processor(images=images, text=text_prompt)
    dialog_tokens = batch["input_ids"].tolist()[0]
    labels = copy.copy(dialog_tokens)

    # Token id that closes every turn (see ``eot_indices`` below).
    eot_id = 128009
    # System prompt header "<|start_header_id|>system<|end_header_id|>" is
    # tokenized to [128006, 9125, 128007]; user prompt header
    # "<|start_header_id|>user<|end_header_id|>" to [128006, 882, 128007].
    prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
    # Assistant header "<|start_header_id|>assistant<|end_header_id|>" is
    # tokenized to [128006, 78191, 128007].
    assistant_header_seq = [128006, 78191, 128007]

    # Positions are collected up front because ``labels`` is edited below.
    eot_indices = [i for i, token in enumerate(labels) if token == eot_id]

    segment_start = 0
    for eot_idx in eot_indices:
        segment_end = eot_idx + 1
        if check_header(prompt_header_seqs, labels[segment_start:segment_end]):
            # Found a prompt header: this whole turn must not contribute to the loss.
            labels[segment_start:segment_end] = [-100] * (segment_end - segment_start)
        # Always step past this turn's end-of-turn token. Previously the
        # cursor was only moved for non-prompt turns and pointed *at* the
        # token, so a following user turn also masked the preceding
        # assistant turn's end-of-turn label.
        segment_start = segment_end

    # Lastly mask every assistant header so only the reply content is learned.
    labels = replace_target(assistant_header_seq, labels)

    combined_tokens = {
        "input_ids": dialog_tokens,
        "labels": labels,
        "attention_mask": [1] * len(dialog_tokens),
        "pixel_values": batch["pixel_values"].tolist()[0],
        "image_sizes": batch["image_sizes"].tolist()[0],
    }
    return combined_tokens
65+
def image_tokenize(sample, processor):
    """Convert one raw dataset sample into tokenized training features.

    Reads ``sample["images"]`` and ``sample["messages"]``, rebuilds each
    message keeping only its ``image`` and ``text`` parts (text is stripped
    of surrounding whitespace; other part types are dropped), and hands the
    resulting dialog to ``tokenize_dialog``.

    Args:
        sample: Mapping with ``"images"`` and ``"messages"`` entries.
        processor: Processor whose tokenizer is switched to right padding
            before tokenization.

    Returns:
        Whatever ``tokenize_dialog`` returns for the rebuilt dialog.
    """
    # During training, padding is always applied on the right.
    processor.tokenizer.padding_side = "right"
    images = sample["images"]
    turns = sample["messages"]

    dialog = []
    for turn in turns:
        raw_parts = turn["content"]
        speaker = turn["role"]
        parts = []
        for part in raw_parts:
            kind = part["type"]
            if kind == "image":
                parts.append({"type": "image"})
            elif kind == "text":
                parts.append({"type": "text", "text": part["text"].strip()})
        dialog.append({"role": speaker, "content": parts})

    return tokenize_dialog(dialog, images, processor)
80+
81+
82+
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    """Load and tokenize a small slice of ``remyxai/vqasynth_spacellava``.

    Args:
        dataset_config: Unused here; kept so the signature stays compatible
            with existing callers.
        processor: Processor passed through to ``image_tokenize``.
        split: Name of the split to take from the loaded ``DatasetDict``.
        split_ratio: Unused here; kept for signature compatibility.

    Returns:
        The tokenized dataset with all original columns removed, so only the
        features produced by ``image_tokenize`` remain.
    """
    # Upper bound on the number of examples taken from the split.
    max_samples = 100

    # load_dataset returns a DatasetDict keyed by split name.
    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
    dataset = dataset_dict[split]
    # Clamp to the split size: selecting indices past the end of a split
    # with fewer than ``max_samples`` rows would raise.
    dataset = dataset.select(range(min(max_samples, len(dataset))))
    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
    return tokenized_datasets

src/llama_recipes/configs/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ class custom_dataset:
3737
class llamaguard_toxicchat_dataset:
3838
dataset: str = "llamaguard_toxicchat_dataset"
3939
train_split: str = "train"
40-
test_split: str = "test"
40+
test_split: str = "test"

src/llama_recipes/finetuning.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
BitsAndBytesConfig,
2323
LlamaForCausalLM,
2424
LlamaConfig,
25+
AutoConfig,
26+
AutoModel,
27+
LlavaNextForConditionalGeneration,
28+
LlavaNextProcessor
29+
2530
)
2631
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
2732

@@ -116,11 +121,11 @@ def main(**kwargs):
116121
bnb_config = quant_config.create_bnb_config(train_config.quantization)
117122

118123
# Load the pre-trained model and setup its configuration
119-
use_cache = False if train_config.enable_fsdp else None
120-
model = LlamaForCausalLM.from_pretrained(
124+
#use_cache = False if train_config.enable_fsdp else None
125+
model = LlavaNextForConditionalGeneration.from_pretrained(
121126
train_config.model_name,
122127
quantization_config=bnb_config,
123-
use_cache=use_cache,
128+
# use_cache=use_cache,
124129
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
125130
device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
126131
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -129,7 +134,8 @@ def main(**kwargs):
129134
# Load the tokenizer and add special tokens
130135
tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
131136
tokenizer.pad_token_id = tokenizer.eos_token_id
132-
137+
processor = LlavaNextProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
138+
processor.tokenizer.padding_side='right'
133139
# If there is a mismatch between tokenizer vocab size and embedding matrix,
134140
# throw a warning and then expand the embedding matrix
135141
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -200,15 +206,15 @@ def main(**kwargs):
200206

201207
# Load and preprocess the dataset for training and validation
202208
dataset_train = get_preprocessed_dataset(
203-
tokenizer,
209+
processor,
204210
dataset_config,
205211
split="train",
206212
)
207213
if not train_config.enable_fsdp or rank == 0:
208214
print(f"--> Training Set Length = {len(dataset_train)}")
209215

210216
dataset_val = get_preprocessed_dataset(
211-
tokenizer,
217+
processor,
212218
dataset_config,
213219
split="test",
214220
)
@@ -219,14 +225,15 @@ def main(**kwargs):
219225
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
220226

221227
train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
222-
228+
print("length of dataset_train", len(dataset_train))
223229
# Create DataLoaders for the training and validation dataset
224230
train_dataloader = torch.utils.data.DataLoader(
225231
dataset_train,
226232
num_workers=train_config.num_workers_dataloader,
227233
pin_memory=True,
228234
**train_dl_kwargs,
229235
)
236+
print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
230237

231238
eval_dataloader = None
232239
if train_config.run_validation:
@@ -241,6 +248,7 @@ def main(**kwargs):
241248
pin_memory=True,
242249
**val_dl_kwargs,
243250
)
251+
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
244252
if len(eval_dataloader) == 0:
245253
raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
246254
else:

src/llama_recipes/utils/config_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,4 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
104104
kwargs["collate_fn"] = default_data_collator
105105
else:
106106
raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
107-
108107
return kwargs

src/llama_recipes/utils/train_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
118118
max_steps_reached = False # Flag to indicate max training steps reached
119119
# Start the training loop
120120
for epoch in range(train_config.num_epochs):
121+
print(f"Starting epoch {epoch}/{train_config.num_epochs}")
122+
print(f"train_config.max_train_step: {train_config.max_train_step}")
121123
# stop when the maximum number of training steps is reached
122124
if max_steps_reached:
123125
break
@@ -130,6 +132,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
130132
with profile(train_config,local_rank) as profile_context:
131133
for step, batch in enumerate(train_dataloader):
132134
total_train_steps += 1
135+
#print("batch: ", batch)
133136
# stop when the maximum number of training steps is reached
134137
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
135138
max_steps_reached = True
@@ -149,8 +152,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
149152
else:
150153
batch[key] = batch[key].to('cuda:0')
151154
with autocast():
155+
assert(next(model.parameters()).device == batch['input_ids'].device)
156+
#print("batch: ", batch)
152157
loss = model(**batch).loss
153158
loss = loss / gradient_accumulation_steps
159+
#print("loss",loss)
154160
if train_config.save_metrics:
155161
train_step_loss.append(loss.detach().float().item())
156162
train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -171,6 +177,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
171177
pbar.update(1)
172178
else:
173179
# regular backpropagation when fp16 is not used
180+
#print("loss123",loss)
174181
loss.backward()
175182
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
176183
if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
@@ -243,12 +250,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
243250
print(f"PEFT modules are saved in {train_config.output_dir} directory")
244251

245252
else:
246-
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
253+
if not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
247254

248255
save_model_checkpoint(
249256
model, optimizer, rank, train_config, epoch=epoch
250257
)
251-
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
258+
elif not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
252259
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
253260
print("=====================================================")
254261

0 commit comments

Comments
 (0)