
Commit 8a11b48

lora+fsdp not working
1 parent 79dbe05 commit 8a11b48

File tree

3 files changed: +15 −10 lines


recipes/quickstart/finetuning/datasets/vqa_dataset.py

Lines changed: 5 additions & 1 deletion
@@ -39,9 +39,13 @@ def tokenize_dialogs(dialogs, images, processor):
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
             else:
                 last_idx = idx+1
-        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        # Mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
         assistant_header_seq = [128006, 78191, 128007]
         labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding token and image token 128256
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: # 128256 is image token index
+                labels[i] = -100
         label_list.append(labels)
     batch["labels"] = torch.tensor(label_list)
     tokenizer_length = len(processor.tokenizer)
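For reference, a minimal sketch of what the new masking step does in isolation. This is not the repo's code: replace_target below is a simplified stand-in for the helper used by vqa_dataset.py, the assistant-header sequence [128006, 78191, 128007] and the image token 128256 are the IDs quoted in the diff, and the pad token ID is an assumed value for illustration only.

# Simplified sketch of the label masking added in this hunk (assumed values noted above).
ASSISTANT_HEADER = [128006, 78191, 128007]  # <|start_header_id|>assistant<|end_header_id|>
IMAGE_TOKEN_ID = 128256                     # image placeholder token (from the diff)
PAD_TOKEN_ID = 128004                       # assumed pad token id, for illustration

def replace_target(seq, labels):
    # Mask every occurrence of `seq` inside `labels` with -100 (simplified stand-in).
    n = len(seq)
    for i in range(len(labels) - n + 1):
        if labels[i:i + n] == seq:
            labels[i:i + n] = [-100] * n
    return labels

def mask_labels(labels):
    labels = replace_target(ASSISTANT_HEADER, list(labels))
    # Pad and image tokens carry no answer text, so they are excluded from the loss.
    return [-100 if tok in (PAD_TOKEN_ID, IMAGE_TOKEN_ID) else tok for tok in labels]

print(mask_labels([128006, 78191, 128007, 1234, 5678, 128004, 128256]))
# -> [-100, -100, -100, 1234, 5678, -100, -100]

Only the answer tokens (1234 and 5678 in the toy example) keep their IDs; everything the model should not be trained to predict becomes -100, which the loss ignores.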

src/llama_recipes/finetuning.py

Lines changed: 7 additions & 9 deletions
@@ -137,6 +137,7 @@ def main(**kwargs):
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
     else:
+        is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
@@ -188,23 +189,20 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # if is_vision:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # else:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-        if train_config.use_peft:
-            wrapping_policy = my_auto_wrapping_policy
-        else:
-            if is_vision:
-                wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
-            else:
-                wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
         model = FSDP(
             model,
-            auto_wrap_policy= wrapping_policy,
+            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
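To make the wrapping change easier to follow, here is a rough, self-contained sketch of how a PEFT-aware auto-wrap policy covering both text and vision blocks can be assembled from PyTorch's FSDP wrap utilities. It approximates what fsdp_auto_wrap_policy does; it is not the repo's exact helper, and the lambda_fn heuristic (wrap leaf modules whose weights require grad, e.g. LoRA adapters) is an assumption for illustration.

import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def sketch_auto_wrap_policy(transformer_layer_classes):
    # Wrap leaf modules whose parameters are trainable (e.g. LoRA adapters) ...
    def lambda_fn(module):
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn)
    # ... and additionally wrap every transformer block of the given classes,
    # here both the language decoder layers and the CLIP vision encoder layers.
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(transformer_layer_classes),
    )
    # A module is wrapped if either sub-policy says so.
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])


my_auto_wrapping_policy = sketch_auto_wrap_policy([LlamaDecoderLayer, CLIPEncoderLayer])
# The diff then hands this policy to FSDP only when PEFT is enabled:
# auto_wrap_policy = my_auto_wrapping_policy if train_config.use_peft else wrapping_policy

The point of a separate policy for the PEFT case is that frozen base weights and trainable adapter weights should not land in the same FSDP flat parameter; wrapping the adapters and each decoder/encoder block individually keeps them separable.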

src/llama_recipes/policies/wrapping.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -27,6 +29,7 @@ def get_llama_wrapper():
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             LlamaDecoderLayer,
+            CLIPEncoderLayer
         },
     )
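As a usage sketch, the policy built in get_llama_wrapper is a functools.partial over transformer_auto_wrap_policy; with CLIPEncoderLayer added to the class set, FSDP shards the vision encoder blocks the same way it already shards the LlamaDecoderLayer blocks. The FSDP call below is left as a comment because it assumes a loaded multimodal model and an initialized process group.

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Equivalent of the policy returned by get_llama_wrapper after this change:
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer, CLIPEncoderLayer},
)

# from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy, ...)  # hypothetical usage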
