Commit abeb12c
fix: subclass Lora config from upstream peft.LoraConfig (#609)
* Updated LoraConfig to subclass from peft.LoraConfig
* Added some fields under custom dataclass to let it pass through HfArgumentParser
* Lint and fmt fixes
* Updated config utils
* Update comment in LoraConfig
* Lint changes
* Updated comment

---------

Signed-off-by: romit <[email protected]>
Signed-off-by: r0 <[email protected]>
Co-authored-by: Dushyant Behl <[email protected]>
1 parent 2949a3a commit abeb12c

File tree: 4 files changed (+119 / -42 lines)

docs/tuning-techniques.md

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 
 ## LoRA Tuning Example
 
-Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
+Set `peft_method` to `"lora"`. You can additionally pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
 ```py
 # Args you can pass
 r: int =8
@@ -340,7 +340,7 @@ You can see details on a sample configuration of Accelerated GPTQ-LoRA [here](ht
 
 To use GPTQ-LoRA technique, you can set the `quantized_lora_config` defined [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/acceleration_configs/quantized_lora_config.py). See the Notes section of FMS Acceleration doc [below](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/README.md#fms-acceleration) for usage. The only kernel we are supporting currently is `triton_v2`.
 
-In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21).
+In addition, LoRA tuning technique is required to be used, set `peft_method` to `"lora"` and pass any arguments from [LoraConfig](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig).
 
 Example command to run:
 
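
For orientation only (not part of this commit): the upstream `peft.LoraConfig` the docs now link to accepts the same kind of arguments the example block above lists. A minimal sketch, assuming `peft` is installed; the values and module names are placeholders.

```py
# Illustration only (not part of this commit): constructing the upstream
# peft.LoraConfig that the updated docs link to. Values and module names
# below are placeholders, not defaults prescribed by fms-hf-tuning.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                                  # LoRA rank
    lora_alpha=32,                        # scaling factor
    lora_dropout=0.05,                    # dropout applied to LoRA layers
    target_modules=["q_proj", "v_proj"],  # module names depend on the model
)
print(lora_config.r, lora_config.lora_alpha)
```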

tuning/config/peft_config.py

Lines changed: 105 additions & 28 deletions
@@ -15,9 +15,10 @@
 # Standard
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 # Third Party
+from peft import LoraConfig as HFLoraConfig
 from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config
 
 
@@ -40,49 +41,125 @@ def to_hf_config(self):
 
 
 @dataclass
-class LoraConfig:
+class LoraConfig(HFLoraConfig):
     """
-    This is the configuration class to store the configuration of a [`LoraModel`].
+    This is the configuration class that extends peft.LoraConfig with a few defaults.
 
     Args:
-        r (`int`):
-            Lora attention dimension (the "rank").
-        target_modules (List[str]]):
-            The names of the modules to apply the adapter to. \
-            If this is specified, only the modules with the specified \
-            names will be replaced. Please specify modules as per model architecture. \
-            If the value is ["all-linear"], \
-            then LORA selects all linear and Conv1D modules as per model architecture, \
-            except for the output layer.
         lora_alpha (`int`):
             The alpha parameter for Lora scaling.
         lora_dropout (`float`):
             The dropout probability for Lora layers.
-        bias (`str`):
-            Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. \
-            If 'all' or 'lora_only', the corresponding biases will be updated during training. \
-            Be aware that this means that, even when disabling the adapters, the model \
-            will not produce the same output as the base model would have without adaptation.
     """
 
-    r: int = 8
     lora_alpha: int = 32
-    target_modules: List[str] = field(
+    lora_dropout: float = 0.05
+
+    # HACK: The following list of arguments listed below
+    # is a fix which reduces the field annotation from
+    # Optional[List[str], str] type to Optional[List[str]] type
+    # This is done for compatibility with HFArgumentParser
+    # Please see: https://github.com/huggingface/peft/issues/2798 for further explanation!
+    target_modules: Optional[List[str]] = field(
         default=None,
         metadata={
-            "help": "The names of the modules to apply LORA to. LORA selects modules which either \
-            completely match or "
-            'end with one of the strings. If the value is ["all-linear"], \
-            then LORA selects all linear and Conv1D '
-            "modules except for the output layer."
+            "help": (
+                "List of module names or regex expression of the module names to replace with LoRA."
+                "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
+                "This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
+                "(if the model is a PreTrainedModel, the output layer excluded). "
+                "If not specified, modules will be chosen according to the model architecture, "
+                "If the architecture is not known, an error will be raised -- "
+                "in this case, you should specify the target modules manually. "
+                "To avoid targeting any modules (because you want to apply `target_parameters`) "
+                ", set `target_modules=[]`."
+            ),
         },
     )
-    target_parameters: List[str] = field(
+    exclude_modules: Optional[List[str]] = field(
         default=None,
-        metadata={"help": "The names/regex of the parameters to apply LORA to"},
+        metadata={
+            "help": (
+                "List of module names or regex expression of the module names to exclude from Lora."
+            )
+        },
     )
-    bias = "none"
-    lora_dropout: float = 0.05
+    init_lora_weights: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "How to initialize the weights of the LoRA layers. "
+                "Passing True (default) results in the default initialization from "
+                "the reference implementation from "
+                "Microsoft, with the LoRA B weight being set to 0. "
+                "This means that without further training, "
+                "the LoRA adapter will be a no-op. "
+                "Setting the initialization to False leads to random initialization of "
+                "LoRA A and B, meaning that LoRA is not a no-op before training; "
+                "this setting is intended for debugging purposes."
+            ),
+        },
+    )
+    layers_to_transform: Optional[list[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The layer indexes to transform, is this argument is specified, "
+                "PEFT will transform only the layers indexes that are specified inside this list. "
+                "If a single integer is passed, PEFT will transform only the layer at this index. "
+                "This only works when target_modules is a list of str."
+            )
+        },
+    )
+    layers_pattern: Optional[list[str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The layer pattern name, used only if `layers_to_transform` is different to None "
+                "and if the layer pattern is not in the common layers pattern. "
+                "This only works when target_modules is a list of str. "
+                "This should target the `nn.ModuleList` of the "
+                "model, which is often called `'layers'` or `'h'`."
+            )
+        },
+    )
+    trainable_token_indices: Optional[list[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Lets you specify which token indices to selectively fine-tune "
+                "without requiring to re-train the "
+                "whole embedding matrix using the `peft.TrainableTokensModel` method. "
+                "You can specify token indices in two ways. "
+                "Either you specify a list of indices which will then target the model's input "
+                "embedding layer (or, if not found, `embed_tokens`). "
+                "(Not supported yet) Alternatively, you can specify a dictionary "
+                "where the key is the name of the embedding module "
+                "and the values are the list of token indices, e.g. "
+                "`{'embed_tokens': [0, 1, ...]}`. Note that training "
+                "with FSDP requires `use_orig_params=True` to "
+                "avoid issues with non-uniform `requires_grad`."
+            )
+        },
+    )
+    loftq_config: Optional[dict] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "The configuration of LoftQ. If this is passed, "
+                "then LoftQ will be used to quantize the backbone "
+                "weights and initialize Lora layers. Also set `init_lora_weights='loftq'` "
+                "in this case."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        # If target_modules is a single-element list, convert it into a plain string
+        if self.target_modules == ["all-linear"]:
+            self.target_modules = "all-linear"
+
+        super().__post_init__()
 
 
 @dataclass
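
To show why the annotation narrowing and the `__post_init__` conversion matter, here is a standalone sketch (not code from this commit) built around a hypothetical `DemoLoraArgs` dataclass: `transformers.HfArgumentParser` handles an `Optional[List[str]]` field by parsing the flag as a list, and the single-element `["all-linear"]` it produces is collapsed back into the `"all-linear"` wildcard string that upstream peft expects.

```py
# Standalone sketch of the HfArgumentParser pattern the HACK comment above
# relies on. DemoLoraArgs is a hypothetical stand-in, not the project's class.
from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser


@dataclass
class DemoLoraArgs:
    # Optional[List[str]] parses cleanly from the CLI (nargs="+"),
    # unlike the Optional[Union[List[str], str]] annotation upstream.
    target_modules: Optional[List[str]] = field(default=None)

    def __post_init__(self):
        # Mirror of the commit's conversion: a single-element ["all-linear"]
        # list becomes the plain "all-linear" wildcard string.
        if self.target_modules == ["all-linear"]:
            self.target_modules = "all-linear"


(args,) = HfArgumentParser(DemoLoraArgs).parse_args_into_dataclasses(
    args=["--target_modules", "all-linear"]
)
print(args.target_modules)  # "all-linear"
```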

tuning/sft_trainer.py

Lines changed: 5 additions & 7 deletions
@@ -71,7 +71,7 @@ def train(
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
     peft_config: Optional[  # pylint: disable=redefined-outer-name
-        Union[peft_config.LoraConfig, LoraConfig, peft_config.PromptTuningConfig]
+        Union[LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
     quantization_config: Optional[peft_config.Mxfp4Config] = None,
     trainer_controller_args: TrainerControllerCallback = None,
@@ -92,8 +92,7 @@ def train(
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
-        peft_config: peft_config.LoraConfig for Lora tuning | \
-            LoraConfig (peft.LoraConfig): for activated Lora (aLoRA) tuning | \
+        peft_config: LoraConfig (peft.LoraConfig): for activated Lora (aLoRA) tuning | \
            peft_config.PromptTuningConfig for prompt tuning | \
            None for full fine tuning
            The peft configuration to pass to trainer
@@ -110,7 +109,8 @@ def train(
            tracker with automatically be added.
        exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker.
        quantized_lora_config: tuning.config.acceleration_configs.QuantizedLoraConfig \
-            Should be used in combination with peft_config.LoraConfig for Lora tuning \
+            Should be used in combination with LoraConfig for Lora tuning \
+            https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig \
        fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \
            Should be used in combination with quantized_lora_config. Also currently
            fused_lora and fast_kernels must used together (may change in future). \
@@ -845,9 +845,7 @@ def main():
         )
         sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
-    if isinstance(
-        tune_config, (peft_config.LoraConfig, LoraConfig)
-    ):  # aLoraConfig subclasses LoraConfig
+    if isinstance(tune_config, LoraConfig):  # aLoraConfig subclasses LoraConfig
        try:
            if training_args.save_model_dir:
                # Write number of added tokens to artifacts
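
A small aside on the simplified `isinstance` check (illustrative stand-in classes only, not the project's real hierarchy): once one config class subclasses the other, a single `isinstance` check against the base matches every derived config as well, so listing both classes in a tuple is redundant.

```py
# Illustrative stand-in classes; the names are hypothetical and only
# demonstrate the Python subclassing behaviour the simplified check relies on.
class UpstreamLoraConfig:                    # stands in for peft.LoraConfig
    pass


class TuningLoraConfig(UpstreamLoraConfig):  # stands in for the subclassed config
    pass


cfg = TuningLoraConfig()
# One check against the base class also covers the subclass,
# which is why a two-class tuple is no longer needed.
assert isinstance(cfg, UpstreamLoraConfig)
assert isinstance(cfg, TuningLoraConfig)
```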

tuning/utils/config_utils.py

Lines changed: 7 additions & 5 deletions
@@ -20,7 +20,6 @@
 import pickle
 
 # Third Party
-from peft import LoraConfig as HFLoraConfig
 from peft import PromptTuningConfig as HFPromptTuningConfig
 
 # Local
@@ -112,10 +111,13 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
         alora_config.task_type = task_type
         hf_peft_config = alora_config
     elif isinstance(tuning_config, peft_config.LoraConfig):
-        lora_config = asdict(tuning_config)
-        if lora_config["target_modules"] == ["all-linear"]:
-            lora_config["target_modules"] = "all-linear"
-        hf_peft_config = HFLoraConfig(task_type=task_type, **lora_config)
+        if getattr(tuning_config, "target_modules") == ["all-linear"]:
+            setattr(tuning_config, "target_modules", "all-linear")
+
+        if getattr(tuning_config, "task_type") is None:
+            setattr(tuning_config, "task_type", task_type)
+
+        hf_peft_config = tuning_config
    elif isinstance(tuning_config, peft_config.PromptTuningConfig):
        hf_peft_config = HFPromptTuningConfig(
            task_type=task_type,
