Skip to content

Commit 091d58d

Browse files
committed
Disable prefix tuning as it's currently not supported; Limit llama_adapter usage to non-FSDP only
1 parent 14e4b05 commit 091d58d

File tree

5 files changed

+15
-12
lines changed

5 files changed

+15
-12
lines changed

recipes/finetuning/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ It lets us specify the training settings for everything from `model_name` to `da
7070

7171
* [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
7272

73-
* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified.
73+
* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified. We currently support LoRA and LLaMA-Adapter. Please note that LoRA is the only technique which is supported in combination with FSDP.
7474

7575
* [FSDP config file](../../src/llama_recipes/configs/fsdp.py) provides FSDP settings such as:
7676

src/llama_recipes/configs/peft.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class llama_adapter_config:
2020
adapter_layers: int= 30
2121
task_type: str= "CAUSAL_LM"
2222

23+
#CAUTION prefix tuning is currently not supported
2324
@dataclass
2425
class prefix_config:
2526
num_virtual_tokens: int=30
26-
task_type: str= "CAUSAL_LM"
27+
task_type: str= "CAUSAL_LM"

src/llama_recipes/configs/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class train_config:
2929
mixed_precision: bool=True
3030
val_batch_size: int=1
3131
dataset = "samsum_dataset"
32-
peft_method: str = "lora" # None,llama_adapter, prefix
32+
peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
3333
use_peft: bool=False
3434
output_dir: str = "PATH/to/save/PEFT/model"
3535
freeze_layers: bool = False

src/llama_recipes/utils/config_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,17 @@ def generate_peft_config(train_config, kwargs):
4545
peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
4646
names = tuple(c.__name__.rstrip("_config") for c in configs)
4747

48-
assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
48+
assert (
49+
train_config.peft_method in names
50+
), f"Peft config not found: {train_config.peft_method}"
51+
52+
assert (
53+
train_config.peft_method != "prefix"
54+
), "PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)"
55+
if train_config.enable_fsdp:
56+
assert (
57+
train_config.peft_method != "llama_adapter"
58+
), "Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)"
4959

5060
config = configs[names.index(train_config.peft_method)]()
5161

src/llama_recipes/utils/fsdp_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
88

99
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
1010

11-
from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
12-
1311
def lambda_policy_fn(module):
1412
if (
1513
len(list(module.named_children())) == 0
@@ -23,13 +21,7 @@ def lambda_policy_fn(module):
2321
transformer_wrap_policy = functools.partial(
2422
transformer_auto_wrap_policy,
2523
transformer_layer_cls=(
26-
PrefixEncoder,
27-
PromptEncoder,
28-
PromptEmbedding,
2924
transformer_layer_name,
30-
# FullyShardedDataParallelPlugin.get_module_class_from_name(
31-
# model, transformer_layer_name
32-
# ),
3325
),
3426
)
3527

0 commit comments

Comments
 (0)