Skip to content
54 changes: 46 additions & 8 deletions src/peft/tuners/loha/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Literal, Optional, Union

from peft.tuners.lycoris_utils import LycorisConfig
from peft.utils import PeftType
Expand Down Expand Up @@ -49,9 +49,12 @@ class LoHaConfig(LycorisConfig):
The names of the modules to not apply the adapter. When passing a string, a regex match will be performed.
When passing a list of strings, either an exact match will be performed or it is checked if the name of the
module ends with any of the passed strings.
init_weights (`bool`):
Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is
discouraged.
init_weights (`Union[bool, Literal["abba"]]`):
How to initialize the weights of the LoHa layers. Pass `True` (default) for default initialization, `False`
for random initialization, or `'abba'` for ABBA initialization, which approximates the pretrained weights
using SVD decomposition, potentially improving training stability and convergence. Based on the ABBA paper:
https://arxiv.org/pdf/2505.14238. See https://github.com/huggingface/peft/issues/2587 for implementation
details.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
Expand All @@ -69,7 +72,24 @@ class LoHaConfig(LycorisConfig):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
"""

r: int = field(default=8, metadata={"help": "LoHa rank"})
r: int = field(
default=8,
metadata={
"help": (
"LoHa rank for the first Hadamard component. For standard LoHa, both components use this rank. "
"For asymmetric ranks, use r2 to specify a different rank for the second component."
)
},
)
r2: Optional[int] = field(
default=None,
metadata={
"help": (
"Rank for the second Hadamard component (w2a @ w2b). "
"If not specified, defaults to r (symmetric ranks)."
)
},
)
alpha: int = field(default=8, metadata={"help": "LoHa alpha"})
rank_dropout: float = field(
default=0.0, metadata={"help": "The dropout probability for rank dimension during training"}
Expand All @@ -86,6 +106,19 @@ class LoHaConfig(LycorisConfig):
)
},
)
use_khatri_rao: Union[bool, Literal["auto"]] = field(
default="auto",
metadata={
"help": (
"Use Khatri-Rao product optimization to reduce memory overhead. "
"This reparameterizes the update using Khatri-Rao product instead of "
"constructing full B1A1 and B2A2 matrices, reducing memory footprint "
"to be similar to LoRA while maintaining expressiveness. "
"When set to 'auto' (default), it is enabled for ABBA initialization (per paper recommendation) "
"and disabled for standard LoHa. Set to True or False to explicitly control this behavior."
)
},
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
Expand All @@ -98,12 +131,17 @@ class LoHaConfig(LycorisConfig):
default=None,
metadata={"help": "List of module names or regex expression of the module names to exclude from LoHa."},
)
init_weights: bool = field(
init_weights: Union[bool, Literal["abba"]] = field(
default=True,
metadata={
"help": (
"Whether to initialize the weights of the LoHa layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
"How to initialize the weights of the LoHa layers. "
"Pass `True` (default) for default initialization (zeros for one matrix), "
"`False` for random initialization, or `'abba'` for ABBA initialization "
"which initializes weights to approximate the pretrained weights using SVD decomposition. "
"ABBA initialization can improve training stability and convergence. "
"Based on the ABBA paper: https://arxiv.org/pdf/2505.14238. "
"See https://github.com/huggingface/peft/issues/2587 for implementation details."
),
},
)
Expand Down
Loading