Commit 1a76080

lora+fsdp working
1 parent 8a11b48 commit 1a76080

3 files changed (+23, -28 lines)

src/llama_recipes/finetuning.py

Lines changed: 14 additions & 11 deletions
@@ -30,7 +30,8 @@
     MllamaForConditionalGeneration
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
@@ -129,7 +130,6 @@ def main(**kwargs):
         model = MllamaForConditionalGeneration.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
-            #use_cache=use_cache,
             attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -146,7 +146,7 @@
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-
+    print(model)
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -189,11 +189,7 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # if is_vision:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # else:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
@@ -222,7 +218,8 @@
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
+    print("-------------------")
+    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -248,7 +245,10 @@
             print(f"--> Validation Set Length = {len(dataset_val)}")

     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
@@ -268,7 +268,10 @@
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
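The wrapping-policy change above is the core of "lora+fsdp working" for the Mllama vision model: each listed decoder and vision encoder layer class becomes its own FSDP unit. A minimal sketch of how such a policy is handed to FSDP follows; the helper name `wrap_model_for_lora_fsdp` and the exact FSDP keyword arguments are illustrative assumptions, not the precise call in finetuning.py.

```python
# Sketch only: hand the auto-wrapping policy from this commit to FSDP.
# Assumes torch.distributed is already initialized on each rank.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaSelfAttentionDecoderLayer,
    MllamaCrossAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy


def wrap_model_for_lora_fsdp(model, use_peft: bool, wrapping_policy):
    # Each listed layer class becomes its own FSDP unit; the lambda policy inside
    # fsdp_auto_wrap_policy additionally isolates trainable leaf modules (LoRA adapters).
    my_auto_wrapping_policy = fsdp_auto_wrap_policy(
        model,
        [LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer,
         MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer],
    )
    return FSDP(
        model,
        auto_wrap_policy=my_auto_wrapping_policy if use_peft else wrapping_policy,
        device_id=torch.cuda.current_device(),
    )
```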

src/llama_recipes/policies/wrapping.py

Lines changed: 2 additions & 5 deletions
@@ -4,7 +4,7 @@
 import functools

 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer

 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
@@ -27,10 +27,7 @@ def get_llama_wrapper():

     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-            CLIPEncoderLayer
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer,MllamaCrossAttentionDecoderLayer])
     )

     return llama_auto_wrap_policy
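For reference, the partial returned by `get_llama_wrapper` is invoked by FSDP once per submodule during auto-wrapping, and a True result with `recurse=False` makes that submodule its own FSDP unit. A small hedged sketch of that call pattern, assuming the torch 2.x policy signature (`nonwrapped_numel`; older releases name this argument `unwrapped_params`) and a hypothetical helper `would_wrap`:

```python
# Sketch: probe which modules the policy from get_llama_wrapper() would wrap.
from llama_recipes.policies.wrapping import get_llama_wrapper

policy = get_llama_wrapper()


def would_wrap(module) -> bool:
    # recurse=False asks: "should this exact module become its own FSDP unit?"
    # For transformer_auto_wrap_policy that is an isinstance check against the
    # classes bound via transformer_layer_cls.
    return policy(module=module, recurse=False, nonwrapped_numel=0)
```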

src/llama_recipes/utils/fsdp_utils.py

Lines changed: 7 additions & 12 deletions
@@ -16,19 +16,14 @@ def lambda_policy_fn(module):
         ):
             return True
         return False
-    transformer_wrap_policies = []
+
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
-    for transformer_layer_name in transformer_layer_names:
-
-        transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls=(
-                transformer_layer_name,
-            ),
-        )
-        transformer_wrap_policies.append(transformer_wrap_policy)
-    policies = transformer_wrap_policies
-    auto_wrap_policy = functools.partial(_or_policy, policies=policies)
+    transformer_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=set(transformer_layer_names)
+    )
+
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
     return auto_wrap_policy
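The refactor replaces the per-class loop with a single `transformer_auto_wrap_policy` over a set of layer classes, OR-ed with a lambda policy that isolates trainable leaf modules, which is how LoRA adapter weights end up in their own FSDP units next to the frozen base weights. Pieced together from the hunk and its context lines, the resulting function looks roughly like the sketch below; minor details of the real file may differ.

```python
# Hedged reconstruction of the refactored fsdp_auto_wrap_policy.
import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)


def fsdp_auto_wrap_policy(model, transformer_layer_names):
    def lambda_policy_fn(module):
        # True for leaf modules whose weights require grad, e.g. LoRA adapter
        # layers, so they get their own FSDP units separate from frozen weights.
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    # One transformer policy over the whole set of layer classes instead of one per class.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(transformer_layer_names),
    )

    # A module is wrapped if either policy says so.
    auto_wrap_policy = functools.partial(
        _or_policy, policies=[lambda_policy, transformer_wrap_policy]
    )
    return auto_wrap_policy
```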