
Commit 2f063e6

ENH: Extend the regex for rank/alpha pattern (#2419)
Supersedes #2382

Right now, the regex used to match the keys passed for `rank_pattern` and `alpha_pattern` requires that either:

1. the module name is identical to the key, or
2. the module name has a prefix and ends on the key.

This is restrictive, since it doesn't allow us to disambiguate between all cases. E.g. if we have a model with these attributes:

- model.foo
- model.bar.foo

we cannot currently target just model.foo. (We can already target only model.bar.foo by passing "bar.foo" as a key to the rank_pattern / alpha_pattern dict.)

This PR makes it possible to pass "^foo" as a key. This way, model.bar.foo is not targeted, since its name does not start with "foo". As a general rule, users who intend a full match should pass the full name of the module preceded by a ^; this is the least ambiguous way.

When running the test cases with the old code, all the test cases with ^ fail, which is fine, since ^ was not working anyway. At the same time, all test cases not using ^ pass, which means the change is backwards compatible.
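To illustrate the intended usage, here is a minimal sketch; the layer names mirror the hypothetical `model.foo` / `model.bar.foo` example above and are not taken from a real model:

```python
from peft import LoraConfig

# Hypothetical layer names as seen from the base model: "foo" and "bar.foo".
config = LoraConfig(
    target_modules=["foo"],
    r=8,                           # default rank
    rank_pattern={"^foo": 16},     # new: anchors at the start, so only "foo" gets rank 16
    alpha_pattern={"bar.foo": 32},  # unchanged suffix behavior: matches "bar.foo"
)
```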
1 parent 37266c1 commit 2f063e6

File tree

13 files changed: +401, -61 lines

docs/source/developer_guides/lora.md

Lines changed: 30 additions & 0 deletions
@@ -239,6 +239,36 @@ Assuming the original model had 5 layers `[0, 1, 2 ,3, 4]`, this would create a
 [Fewshot-Metamath-OrcaVicuna-Mistral-10B](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B) is an example of a model trained using this method on Mistral-7B expanded to 10B. The
 [adapter_config.json](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B/blob/main/adapter_config.json) shows a sample LoRA adapter config applying this method for fine-tuning.
 
+### Fine grained control over ranks and alpha (scaling)
+
+By default, all layers targeted with LoRA will have the same rank `r` and the same `lora_alpha` (which determines the LoRA scaling), as specified in the [`LoraConfig`]. In some cases, however, you may want to use different values for different layers. This is possible by passing the `rank_pattern` and `alpha_pattern` arguments to [`LoraConfig`]. These arguments should be dictionaries with the key being the layer name and the value being the rank/alpha value. The keys can be [regular expressions](https://docs.python.org/3/library/re.html) (regex). All LoRA layers that are not explicitly mentioned in `rank_pattern` and `alpha_pattern` will take the default `r` and `lora_alpha` values.
+
+To give an example, let's assume that we have a model with the following structure:
+
+```python
+>>> print(model)
+Outer(
+  (foo): Linear(...)
+  (module): Middle(
+    (foo): Linear(...)
+    (foobar): Linear(...)
+    (module): Inner(
+      (foo): Linear(...)
+      (barfoo): Linear(...)
+    )
+  )
+)
+```
+
+- `rank_pattern={"foo": 42}` will match all 3 `foo` layers. Neither `foobar` nor `barfoo` is matched.
+- `rank_pattern={"^foo": 42}` will only match the `foo` layer of the model, but neither `module.foo` nor `module.module.foo`. This is because the `^` means "start of string" in regular expressions, and only `foo` starts with `"foo"`; the other layer names have prefixes.
+- `rank_pattern={"^module.foo": 42}` matches only `module.foo`, but not `module.module.foo`, for the same reason.
+- `rank_pattern={"module.foo": 42}` matches both `module.foo` and `module.module.foo`, but not `foo`.
+- `rank_pattern={"^foo": 42, "^module.module.foo": 55}` matches `foo` and `module.module.foo`, respectively, but not `module.foo`.
+- There is no need to add `$` to mark the end of the match, as this is appended automatically by PEFT.
+
+The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get fancy with regular expressions -- just pass the full name of each module that needs a different rank/alpha, preceded by the `^` prefix, and you should be good.
+
 ## Optimizers
 
 LoRA training can optionally include special purpose optimizers. Currently the only such optimizer is LoRA+.
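As a quick illustration of the matching rules documented above, a config targeting the toy `Outer`/`Middle`/`Inner` model could look like the following sketch (for illustration only, not part of the committed docs):

```python
from peft import LoraConfig

# Toy layer names from the docs example: "foo", "module.foo", "module.module.foo".
config = LoraConfig(
    r=8,             # default rank for targeted layers that match no pattern
    lora_alpha=16,   # default scaling
    target_modules=["foo"],
    # "^foo" anchors at the start: only the outermost `foo` gets rank 42;
    # "^module.module.foo" pins the innermost `foo` to rank 55.
    rank_pattern={"^foo": 42, "^module.module.foo": 55},
    # Suffix match (no "^"): both `module.foo` and `module.module.foo` get alpha 32.
    alpha_pattern={"module.foo": 32},
)
# `module.foo` keeps the default rank 8, since no rank_pattern key matches it.
```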

src/peft/tuners/bone/config.py

Lines changed: 0 additions & 3 deletions
@@ -51,9 +51,6 @@ class BoneConfig(PeftConfig):
             layer at this index.
         layers_pattern (`str`):
             The layer pattern name, used only if `layers_to_transform` is different from `None`.
-        rank_pattern (`dict`):
-            The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
         modules_to_save (`List[str]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
     """

src/peft/tuners/hra/config.py

Lines changed: 0 additions & 3 deletions
@@ -53,9 +53,6 @@ class HRAConfig(PeftConfig):
         layers_pattern (`Optional[Union[List[str], str]]`):
             The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the
             `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
-        rank_pattern (`dict`):
-            The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
         modules_to_save (`List[str]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
     """

src/peft/tuners/loha/config.py

Lines changed: 2 additions & 2 deletions
@@ -60,10 +60,10 @@ class LoHaConfig(LycorisConfig):
             `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
         rank_pattern (`dict`):
             The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
+            specified by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         alpha_pattern (`dict`):
             The mapping from layer names or regexp expression to alphas which are different from the default alpha
-            specified by `alpha`.
+            specified by `alpha`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         modules_to_save (`Optional[List[str]]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
     """

src/peft/tuners/loha/model.py

Lines changed: 5 additions & 9 deletions
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-from itertools import chain
 from typing import Dict, Type, Union
 
 import torch
 from torch import nn
 
 from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
+from peft.utils.other import get_pattern_key
 
 from .layer import Conv2d, Linear, LoHaLayer
 
@@ -100,14 +99,11 @@ def _create_and_replace(
         """
         A private method to create and replace the target module with the adapter module.
         """
-
-        # Regexp matching - Find key which matches current target_name in patterns provided
-        pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
-        target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name)
-
+        r_key = get_pattern_key(config.rank_pattern.keys(), current_key)
+        alpha_key = get_pattern_key(config.alpha_pattern.keys(), current_key)
         kwargs = config.to_dict()
-        kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
-        kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
+        kwargs["r"] = config.rank_pattern.get(r_key, config.r)
+        kwargs["alpha"] = config.alpha_pattern.get(alpha_key, config.alpha)
 
         if isinstance(target, LoHaLayer):
             target.update_layer(adapter_name, **kwargs)
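The new helper is imported from `peft.utils.other`, but its implementation is not part of this hunk. Below is a rough sketch of what it plausibly does, inferred from the replaced inline logic and the documented `^` semantics; this is an assumption, not the actual PEFT source:

```python
import re
from typing import Iterable


def get_pattern_key(pattern_keys: Iterable[str], key_to_match: str) -> str:
    """Sketch only: pick the pattern key that applies to the given module name.

    A pattern matches if it covers the full name after an optional dotted
    prefix; a pattern starting with "^" can therefore only match from the very
    beginning of the name. The trailing "$" is appended here, so callers do not
    need to add it. If nothing matches, the module name itself is returned; it
    is absent from the pattern dict, so the caller's default rank/alpha applies.
    """
    for pattern in pattern_keys:
        if re.match(rf"(.*\.)?{pattern}$", key_to_match):
            return pattern
    return key_to_match
```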

src/peft/tuners/lokr/config.py

Lines changed: 2 additions & 2 deletions
@@ -66,10 +66,10 @@ class LoKrConfig(LycorisConfig):
             `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
         rank_pattern (`dict`):
             The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
+            specified by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         alpha_pattern (`dict`):
             The mapping from layer names or regexp expression to alphas which are different from the default alpha
-            specified by `alpha`.
+            specified by `alpha`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         modules_to_save (`Optional[List[str]]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
     """

src/peft/tuners/lokr/model.py

Lines changed: 5 additions & 9 deletions
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-from itertools import chain
 from typing import Dict, Type, Union
 
 import torch
 from torch import nn
 
 from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
+from peft.utils.other import get_pattern_key
 
 from .layer import Conv2d, Linear, LoKrLayer
 
@@ -101,14 +100,11 @@ def _create_and_replace(
         """
         A private method to create and replace the target module with the adapter module.
         """
-
-        # Regexp matching - Find key which matches current target_name in patterns provided
-        pattern_keys = list(chain(config.rank_pattern.keys(), config.alpha_pattern.keys()))
-        target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name)
-
+        r_key = get_pattern_key(config.rank_pattern.keys(), current_key)
+        alpha_key = get_pattern_key(config.alpha_pattern.keys(), current_key)
         kwargs = config.to_dict()
-        kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
-        kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
+        kwargs["r"] = config.rank_pattern.get(r_key, config.r)
+        kwargs["alpha"] = config.alpha_pattern.get(alpha_key, config.alpha)
         kwargs["rank_dropout_scale"] = config.rank_dropout_scale
 
         if isinstance(target, LoKrLayer):

src/peft/tuners/lora/config.py

Lines changed: 4 additions & 4 deletions
@@ -262,10 +262,10 @@ class LoraConfig(PeftConfig):
             `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
         rank_pattern (`dict`):
             The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
+            specified by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         alpha_pattern (`dict`):
             The mapping from layer names or regexp expression to alphas which are different from the default alpha
-            specified by `lora_alpha`.
+            specified by `lora_alpha`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
         megatron_config (`Optional[dict]`):
             The TransformerConfig arguments for Megatron. It is used to create LoRA's parallel linear layer. You can
             get it like this, `core_transformer_config_from_args(get_args())`, these two functions being from Megatron.
@@ -399,7 +399,7 @@ class LoraConfig(PeftConfig):
         metadata={
             "help": (
                 "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
+                "For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`."
             )
         },
     )
@@ -408,7 +408,7 @@ class LoraConfig(PeftConfig):
         metadata={
             "help": (
                 "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
+                "For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`."
            )
         },
     )

src/peft/tuners/lycoris_utils.py

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ class LycorisConfig(PeftConfig):
         metadata={
             "help": (
                 "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
+                "For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`."
             )
         },
     )
@@ -51,7 +51,7 @@ class LycorisConfig(PeftConfig):
         metadata={
             "help": (
                 "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `alpha`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
+                "For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`."
             )
         },
     )

src/peft/tuners/oft/config.py

Lines changed: 0 additions & 23 deletions
@@ -57,9 +57,6 @@ class OFTConfig(PeftConfig):
         layers_pattern (`Optional[Union[List[str], str]]`):
             The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the
             `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
-        rank_pattern (`dict`):
-            The mapping from layer names or regexp expression to ranks which are different from the default rank
-            specified by `r`.
         modules_to_save (`List[str]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
         coft (`bool`):
@@ -147,26 +144,6 @@
         default=False,
         metadata={"help": "Whether to share the OFT parameters between blocks or not."},
     )
-    rank_pattern: Optional[dict] = field(
-        default_factory=dict,
-        metadata={
-            "help": (
-                "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
-                "Important: the rank pattern won't be applied to the layers after 0.12.1.dev0!"
-            )
-        },
-    )
-    alpha_pattern: Optional[dict] = field(
-        default_factory=dict,
-        metadata={
-            "help": (
-                "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `alpha`. "
-                "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
-                "Important: the alpha pattern won't be applied to the layers after 0.12.1.dev0!"
-            )
-        },
-    )
 
     def __post_init__(self):
         super().__post_init__()
