Skip to content

Commit bb990be

Browse files
committed
not working, need to create dataloader function
1 parent c38cccb commit bb990be

File tree

4 files changed

+64
-48
lines changed

4 files changed

+64
-48
lines changed

recipes/quickstart/finetuning/datasets/vqa_dataset.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
from datasets import load_dataset
77
import itertools
8+
import torch
89
# check system prompt token seq or user prompt token seq is in the current token list
910
def check_header(targets,seq):
1011
for i in range(len(seq)-3):
@@ -20,13 +21,8 @@ def tokenize_dialog(dialog, images, processor):
2021
# If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
2122
text_prompt = processor.apply_chat_template(dialog)
2223
#print("text_prompt",text_prompt)
23-
batch = processor(images=images, text=text_prompt)
24-
dialog_tokens = batch["input_ids"].tolist()[0]
25-
#print("dialog_tokens",dialog_tokens)
26-
#print("dialog_tokens",dialog_tokens)
27-
attention_mask = batch["attention_mask"].tolist()[0]
28-
#print("attention_mask",attention_mask)
29-
labels = copy.copy(dialog_tokens)
24+
batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
25+
labels = copy.copy(batch["input_ids"].tolist()[0])
3026
eot_indices = [i for i,n in enumerate(labels) if n == 128009]
3127
last_idx = 0
3228
# system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
@@ -43,25 +39,34 @@ def tokenize_dialog(dialog, images, processor):
4339
assistant_header_seq = [128006, 78191, 128007]
4440
labels = replace_target(assistant_header_seq,labels)
4541
#print("labels",labels)
42+
# print("pixel_values .shape",batch["pixel_values"].shape)
43+
# print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
4644

47-
48-
combined_tokens = {
49-
# "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
50-
# "labels": list(itertools.chain(*(t for t in labels_tokens))),
51-
"input_ids": dialog_tokens,
52-
"labels": labels,
53-
"attention_mask": [1]*len(dialog_tokens),
54-
"pixel_values": batch["pixel_values"].tolist()[0],
55-
"image_sizes": batch["image_sizes"].tolist()[0]
56-
}
45+
batch["labels"] = torch.tensor(labels)
46+
#pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
47+
batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
48+
# pixel_values .shape torch.Size([1, 4, 3, 560, 560])
49+
print("pixel_values .shape",batch["pixel_values"].shape)
50+
# exit()
51+
# combined_tokens = {
52+
# # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
53+
# # "labels": list(itertools.chain(*(t for t in labels_tokens))),
54+
# "input_ids": dialog_tokens,
55+
# "labels": labels,
56+
# "attention_mask": [1]*len(dialog_tokens),
57+
# "pixel_values": batch["pixel_values"],
58+
# "aspect_ratio_ids": batch["aspect_ratio_ids"],
59+
# "aspect_ratio_mask": batch["aspect_ratio_mask"],
60+
# "cross_attention_mask": batch["cross_attention_mask"]
61+
# }
5762
# input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
5863
# labels = list(itertools.chain(*(t for t in labels_tokens))),
5964
# attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
6065
# pixel_values = batch["pixel_values"],
6166
# image_sizes = batch["image_sizes"]
6267
# print("combined_tokens",combined_tokens[image_sizes])
6368

64-
return combined_tokens
69+
return batch
6570
def image_tokenize(sample, processor):
6671
processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
6772
images,sample_text = sample["images"],sample["messages"]

src/llama_recipes/finetuning.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@
2626
BitsAndBytesConfig,
2727
LlamaForCausalLM,
2828
LlamaConfig,
29-
AutoConfig,
30-
AutoModel,
31-
LlavaNextForConditionalGeneration,
32-
LlavaNextProcessor
33-
29+
AutoProcessor,
30+
MllamaForConditionalGeneration
3431
)
3532
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
3633
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
@@ -126,20 +123,32 @@ def main(**kwargs):
126123

127124
# Load the pre-trained model and setup its configuration
128125
use_cache = False if train_config.enable_fsdp else None
129-
model = LlavaNextForConditionalGeneration.from_pretrained(
126+
if "11B" in train_config.model_name or "90B" in train_config.model_name:
127+
is_vision = True
128+
model = MllamaForConditionalGeneration.from_pretrained(
130129
train_config.model_name,
131130
quantization_config=bnb_config,
132131
#use_cache=use_cache,
133132
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
134133
device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
135134
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
136135
)
136+
processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
137+
processor.tokenizer.padding_side='right'
138+
else:
139+
model = LlamaForCausalLM.from_pretrained(
140+
train_config.model_name,
141+
quantization_config=bnb_config,
142+
use_cache=use_cache,
143+
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
144+
device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
145+
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
146+
)
137147

138148
# Load the tokenizer and add special tokens
139149
tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
140150
tokenizer.pad_token_id = tokenizer.eos_token_id
141-
processor = LlavaNextProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
142-
processor.tokenizer.padding_side='right'
151+
143152
# If there is a mismatch between tokenizer vocab size and embedding matrix,
144153
# throw a warning and then expand the embedding matrix
145154
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -183,18 +192,16 @@ def main(**kwargs):
183192
device_id = torch.xpu.current_device()
184193
elif torch.cuda.is_available():
185194
device_id = torch.cuda.current_device()
186-
# print(dir(model))
187-
# for layer in model.named_children():
188-
# print(f"Layer: {layer}")
189-
190-
# layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
191-
# for name, param in layernorm.named_parameters():
192-
# print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
193-
# exit()
195+
if train_config.use_peft:
196+
wrapping_policy = my_auto_wrapping_policy
197+
else:
198+
if is_vision:
199+
wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
200+
else:
201+
wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
194202
model = FSDP(
195203
model,
196-
auto_wrap_policy= ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]),
197-
#auto_wrap_policy= my_auto_wrapping_policy, #if train_config.use_peft else wrapping_policy,
204+
auto_wrap_policy= wrapping_policy,
198205
cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
199206
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
200207
sharding_strategy=fsdp_config.sharding_strategy,
@@ -205,10 +212,9 @@ def main(**kwargs):
205212
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
206213
if train_config.low_cpu_fsdp and rank != 0 else None,
207214
)
208-
#print(model)
209215
if fsdp_config.fsdp_activation_checkpointing:
210216
model.enable_input_require_grads()
211-
model.gradient_checkpointing_enable()
217+
#model.gradient_checkpointing_enable()
212218
apply_fsdp_checkpointing(model)
213219
elif not train_config.quantization and not train_config.enable_fsdp:
214220
if is_xpu_available():
@@ -217,23 +223,23 @@ def main(**kwargs):
217223
model.to("cuda")
218224

219225
dataset_config = generate_dataset_config(train_config, kwargs)
226+
if is_vision:
227+
dataset_processer = processor
228+
else:
229+
dataset_processer = tokenizer
230+
231+
# Load and preprocess the dataset for training and validation
220232

221-
# Load and preprocess the dataset for training and validation
222-
# dataset_train = get_preprocessed_dataset(
223-
# processor,
224-
# dataset_config,
225-
# split="train",
226-
# )
227233
dataset_train = get_preprocessed_dataset(
228-
processor,
234+
dataset_processer,
229235
dataset_config,
230236
split="train",
231237
)
232238
if not train_config.enable_fsdp or rank == 0:
233239
print(f"--> Training Set Length = {len(dataset_train)}")
234240

235241
dataset_val = get_preprocessed_dataset(
236-
processor,
242+
dataset_processer,
237243
dataset_config,
238244
split="test",
239245
)

src/llama_recipes/utils/config_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def generate_dataset_config(train_config, kwargs):
7575
return dataset_config
7676

7777

78-
def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
78+
def get_dataloader_kwargs(train_config, dataset, tokenizer, mode,collate_fn=None):
7979
kwargs = {}
8080
batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
8181
if train_config.batching_strategy == "padding":
@@ -89,7 +89,10 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
8989
)
9090
else:
9191
kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
92-
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
92+
if not collate_fn:
93+
kwargs["collate_fn"] = collate_fn
94+
else:
95+
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
9396
elif train_config.batching_strategy == "packing":
9497
if train_config.enable_fsdp:
9598
kwargs["sampler"] = DistributedSampler(

src/llama_recipes/utils/train_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
154154
with autocast():
155155
assert(next(model.parameters()).device == batch['input_ids'].device)
156156
#print("batch: ", batch)
157+
pixel_values = batch['pixel_values']
158+
print("pixel_values.shape input",pixel_values.shape)
157159
loss = model(**batch).loss
158160
loss = loss / gradient_accumulation_steps
159161
#print("loss",loss)

0 commit comments

Comments (0)