
Commit 5f11aeb

Disable prefix tuning and limit llama adapter (meta-llama#482)
2 parents 3d5c701 + 4b93dc6

6 files changed: +155 -67 lines


recipes/finetuning/README.md

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ It lets us specify the training settings for everything from `model_name` to `da

 * [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

-* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified.
+* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified. We currently support LoRA and Llama-Adapter. Please note that LoRA is the only technique which is supported in combination with FSDP.

 * [FSDP config file](../../src/llama_recipes/configs/fsdp.py) provides FSDP settings such as:
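In practice the PEFT method is picked through the training config. A minimal Python sketch, assuming only the `train_config` fields visible elsewhere in this commit (LoRA hyperparameters stay at their defaults):

from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.use_peft = True
cfg.peft_method = "lora"   # "llama_adapter" also works, but only without FSDP
cfg.enable_fsdp = True     # permitted here only because peft_method is "lora"

Prefix tuning is no longer accepted at all; see the guard added in config_utils.py below.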

src/llama_recipes/configs/peft.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@ class llama_adapter_config:
     adapter_layers: int= 30
     task_type: str= "CAUSAL_LM"

+#CAUTION prefix tuning is currently not supported
 @dataclass
 class prefix_config:
     num_virtual_tokens: int=30
-    task_type: str= "CAUSAL_LM"
+    task_type: str= "CAUSAL_LM"
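As an aside, the user-facing peft_method strings are derived from these dataclass names inside generate_peft_config (see the config_utils.py hunk below) via rstrip("_config"). A tiny sketch of that mapping; note that rstrip strips a trailing character set rather than a literal suffix, which happens to yield the intended names here:

# Derive the accepted peft_method strings from the config class names.
for cls_name in ("lora_config", "llama_adapter_config", "prefix_config"):
    print(cls_name.rstrip("_config"))  # -> lora, llama_adapter, prefix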

src/llama_recipes/configs/training.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
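To make the new caution concrete, the combination below is the one that now fails fast. It mirrors the test added at the bottom of this commit and is a sketch only; a real run would additionally need a model checkpoint, a dataset, and a distributed launch:

from llama_recipes.finetuning import main

kwargs = {"use_peft": True, "peft_method": "llama_adapter", "enable_fsdp": True}
main(**kwargs)  # raises RuntimeError: Llama_adapter is currently not supported in combination with FSDP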

src/llama_recipes/utils/config_utils.py

Lines changed: 8 additions & 1 deletion
@@ -45,7 +45,14 @@ def generate_peft_config(train_config, kwargs):
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)

-    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
+    if train_config.peft_method not in names:
+        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+
+    if train_config.peft_method == "prefix":
+        raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)")
+
+    if train_config.enable_fsdp and train_config.peft_method == "llama_adapter":
+        raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)")

     config = configs[names.index(train_config.peft_method)]()
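A minimal sketch of the new guard in isolation, assuming llama_recipes is importable. The stand-in dataclass below is hypothetical (the real object is llama_recipes.configs.training.train_config); only the attributes the guard reads are provided:

from dataclasses import dataclass

from llama_recipes.utils.config_utils import generate_peft_config


@dataclass
class _StubTrainConfig:  # hypothetical stand-in for train_config
    peft_method: str = "prefix"
    enable_fsdp: bool = False


try:
    generate_peft_config(_StubTrainConfig(), {})
except RuntimeError as err:
    print(err)  # PrefixTuning is currently not supported (...)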

src/llama_recipes/utils/fsdp_utils.py

Lines changed: 0 additions & 8 deletions
@@ -8,8 +8,6 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):

     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

-    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
-
     def lambda_policy_fn(module):
         if (
             len(list(module.named_children())) == 0
@@ -23,13 +21,7 @@ def lambda_policy_fn(module):
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls=(
-            PrefixEncoder,
-            PromptEncoder,
-            PromptEmbedding,
             transformer_layer_name,
-            # FullyShardedDataParallelPlugin.get_module_class_from_name(
-            #     model, transformer_layer_name
-            # ),
         ),
     )
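For reference, a sketch of the whole helper after these deletions; lines outside the hunks above are reconstructed from context and may differ slightly from the repository:

import functools


def fsdp_auto_wrap_policy(model, transformer_layer_name):
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    def lambda_policy_fn(module):
        # Wrap leaf modules that carry trainable weights (e.g. LoRA layers).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    # With the prompt-learning encoder classes gone, only the caller-supplied
    # transformer block class is matched by the transformer policy.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_name,),
    )

    return functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])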

tests/test_finetuning.py

Lines changed: 143 additions & 55 deletions
@@ -1,40 +1,56 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

-import pytest
-from pytest import approx
+import os
 from unittest.mock import patch

+import pytest
+
 import torch
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+from llama_recipes.finetuning import main
+from pytest import approx
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.sampler import BatchSampler

-from llama_recipes.finetuning import main
-from llama_recipes.data.sampler import LengthBasedBatchSampler
-

 def get_fake_dataset():
-    return [{
-        "input_ids":[1],
-        "attention_mask":[1],
-        "labels":[1],
-        }]
-
-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+    return [
+        {
+            "input_ids": [1],
+            "attention_mask": [1],
+            "labels": [1],
+        }
+    ]
+
+
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_no_validation(
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"run_validation": False}

     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available

+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)

     assert train.call_count == 1
@@ -53,20 +69,31 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
         assert get_model.return_value.to.call_count == 0


-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_with_validation(
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"run_validation": True}

     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available

+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)

     assert train.call_count == 1
@@ -83,22 +110,36 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     else:
         assert get_model.return_value.to.call_count == 0

-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.generate_peft_config')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.generate_peft_config")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_peft_lora(
+    step_lr,
+    optimizer,
+    get_peft_model,
+    gen_peft_config,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"use_peft": True}

     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available

+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)

     if cuda_is_available:
@@ -110,21 +151,64 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1


-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
-    kwargs = {"weight_decay": 0.01}
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.setup")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+def test_finetuning_peft_llama_adapter(
+    get_dataset, tokenizer, get_model, train, setup, get_peft_model
+):
+    kwargs = {
+        "use_peft": True,
+        "peft_method": "llama_adapter",
+        "enable_fsdp": True,
+    }

     get_dataset.return_value = get_fake_dataset()

-    model = mocker.MagicMock(name="Model")
-    model.parameters.return_value = [torch.ones(1,1)]
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+    os.environ["RANK"] = "0"
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12345"
+
+    with pytest.raises(
+        RuntimeError,
+        match="Llama_adapter is currently not supported in combination with FSDP",
+    ):
+        main(**kwargs)
+
+    GET_ME_OUT = "Get me out of here"
+    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
+
+    kwargs["enable_fsdp"] = False
+
+    with pytest.raises(
+        RuntimeError,
+        match=GET_ME_OUT,
+    ):
+        main(**kwargs)
+

-    get_model.return_value = model
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.StepLR")
+def test_finetuning_weight_decay(
+    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+):
+    kwargs = {"weight_decay": 0.01}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

     main(**kwargs)

@@ -139,17 +223,21 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)


-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
-def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
+def test_batching_strategy(
+    step_lr, optimizer, get_dataset, tokenizer, get_model, train
+):
     kwargs = {"batching_strategy": "packing"}

     get_dataset.return_value = get_fake_dataset()

+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)

     assert train.call_count == 1
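To run just the new guard test locally, a sketch (assumes pytest is installed and the repository root is the current working directory):

import pytest

# Select only the llama_adapter/FSDP guard test added in this commit.
raise SystemExit(pytest.main(["tests/test_finetuning.py", "-k", "llama_adapter"]))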
