28 changes: 28 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -571,6 +571,19 @@ class LoraConfig(PeftConfig):
)
},
)
use_wora: bool = field(
default=False,
metadata={
"help": (
"Enable 'Weighted-Direction Low-Rank Adaptation' (WoRA). WoRA extends DoRA by adding learnable "
"scalar parameters (alpha, beta) that weight the base weights and low-rank update before normalization, "
"allowing the model to learn the optimal trade-off between pretrained knowledge and task-specific adaptations. "
"See <a href='https://arxiv.org/pdf/2404.10292'>the WoRA paper</a> for details. "
"Note: WoRA cannot be used simultaneously with DoRA or QALoRA. WoRA supports linear, embedding, and "
"convolutional layers. Like DoRA, it is recommended to merge weights for inference."
)
},
)
alora_invocation_tokens: Optional[list[int]] = field(
default=None,
metadata={
@@ -714,6 +727,19 @@ def __post_init__(self):
if self.use_dora and self.megatron_config:
raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.")

# Check WoRA conflicts
if self.use_wora and self.use_dora:
raise ValueError("Cannot use both WoRA and DoRA simultaneously. Please set one to False.")

if self.use_wora and self.use_qalora:
raise ValueError("WoRA with QALoRA is not supported yet. Please set one to False.")

if self.use_wora and self.alora_invocation_tokens is not None:
raise ValueError("WoRA with aLoRA is not supported. Please set use_wora=False or alora_invocation_tokens=None.")

if self.use_wora and self.megatron_config is not None:
raise ValueError("WoRA does not support megatron_core, please set `use_wora=False`.")

# handle init_lora_weights and loftq_config
if self.init_lora_weights == "loftq":
import importlib
@@ -751,6 +777,8 @@ def __post_init__(self):
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")
if self.use_wora:
raise ValueError("The argument lora_bias=True is not supported for WoRA, please pass use_wora=False")

if self.alora_invocation_tokens is not None and self.task_type != "CAUSAL_LM":
warnings.warn("aLoRA is currently only supported for CAUSAL_LM task.")
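For orientation, here is a hedged usage sketch (not part of this diff) showing how the new use_wora flag would be enabled from user code. The base model, target modules, and hyperparameters are illustrative assumptions; only the config fields come from this PR.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative base model; any architecture supported by PEFT's LoRA tuner would do.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # architecture-dependent placeholder
    use_wora=True,    # new flag introduced by this PR
    use_dora=False,   # must stay False: __post_init__ rejects WoRA + DoRA
    lora_bias=False,  # lora_bias=True is rejected for WoRA, as for DoRA
)

peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()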
110 changes: 84 additions & 26 deletions src/peft/tuners/lora/layer.py
@@ -93,7 +93,7 @@ def forward(

class LoraLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names: tuple[str, ...] = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
adapter_layer_names: tuple[str, ...] = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B", "lora_wora_alpha", "lora_wora_beta")
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str, ...] = ("r", "lora_alpha", "scaling", "lora_dropout")

@@ -113,8 +113,11 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.merged_adapters = []
self.use_dora: dict[str, bool] = {} # not actively used anymore after #2443, keep it for BC
self.use_rslora: dict[str, bool] = {}
self.use_wora: dict[str, bool] = {} # for WoRA
self.lora_bias: dict[str, bool] = {}
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA and WoRA
self.lora_wora_alpha = torch.nn.ParameterDict() # for WoRA: learnable alpha scalars
self.lora_wora_beta = torch.nn.ParameterDict() # for WoRA: learnable beta scalars
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
# flag to enable/disable casting of input to weight dtype during forward call
@@ -127,11 +130,12 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_wora: bool = False, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA.
method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA. If `use_wora=True`,
same for WoRA.

If there is no fitting variant, return None.

@@ -152,6 +156,7 @@ def update_layer(
use_dora: bool = False,
use_alora: bool = False,
use_qalora: bool = False,
use_wora: bool = False,
lora_bias: bool = False,
arrow_config: ArrowConfig = None,
qalora_group_size: int = 32,
@@ -177,6 +182,7 @@ def update_layer(
use_dora=use_dora,
use_alora=use_alora,
use_qalora=use_qalora,
use_wora=use_wora,
qalora_group_size=qalora_group_size,
arrow_config=arrow_config,
)
@@ -206,6 +212,16 @@ def update_layer(

self.use_dora[adapter_name] = use_dora

self.use_wora[adapter_name] = use_wora

# Initialize WoRA parameters if needed
if use_wora:
self.lora_wora_alpha[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
self.lora_wora_beta[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
# Explicitly ensure they remain trainable (in case something tries to freeze them)
self.lora_wora_alpha[adapter_name].requires_grad_(True)
self.lora_wora_beta[adapter_name].requires_grad_(True)

# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
@@ -611,6 +627,7 @@ def __init__(
use_rslora: bool = False,
use_dora: bool = False,
use_alora: bool = False,
use_wora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -629,25 +646,28 @@ def __init__(
use_rslora=use_rslora,
use_dora=use_dora,
use_alora=use_alora,
use_wora=use_wora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, use_wora: bool = False, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant

return ArrowLinearVariant()

if not use_dora and not use_alora:
if not use_wora and not use_dora and not use_alora:
return None

from .variants import ALoraLinearVariant, DoraLinearVariant
from .variants import ALoraLinearVariant, DoraLinearVariant, WoraLinearVariant

if use_alora:
if use_wora:
return WoraLinearVariant()
elif use_alora:
return ALoraLinearVariant()
else:
return DoraLinearVariant()
@@ -837,6 +857,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_wora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -858,17 +879,21 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_wora=use_wora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_wora: bool = False, **kwargs) -> Optional[LoraVariant]:
if not use_wora and not use_dora:
return None

from .variants import DoraEmbeddingVariant
from .variants import DoraEmbeddingVariant, WoraEmbeddingVariant

return DoraEmbeddingVariant()
if use_wora:
return WoraEmbeddingVariant()
else:
return DoraEmbeddingVariant()

def update_layer(
self,
@@ -880,6 +905,7 @@ def update_layer(
use_rslora,
use_dora,
lora_bias,
use_wora: bool = False,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
**kwargs,
@@ -891,7 +917,7 @@ def update_layer(
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_wora=use_wora, arrow_config=arrow_config)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -919,6 +945,16 @@ def update_layer(

self.use_dora[adapter_name] = use_dora

self.use_wora[adapter_name] = use_wora

# Initialize WoRA parameters if needed
if use_wora:
self.lora_wora_alpha[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
self.lora_wora_beta[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
# Explicitly ensure they remain trainable
self.lora_wora_alpha[adapter_name].requires_grad_(True)
self.lora_wora_beta[adapter_name].requires_grad_(True)

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
@@ -1147,6 +1183,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_wora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -1176,6 +1213,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_wora=use_wora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
@@ -1190,6 +1228,7 @@ def update_layer(
use_rslora,
use_dora,
lora_bias,
use_wora: bool = False,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
**kwargs,
@@ -1208,7 +1247,7 @@ def update_layer(
PeftWarning,
)

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_wora=use_wora, arrow_config=arrow_config)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1242,6 +1281,16 @@ def update_layer(

self.use_dora[adapter_name] = use_dora

self.use_wora[adapter_name] = use_wora

# Initialize WoRA parameters if needed
if use_wora:
self.lora_wora_alpha[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
self.lora_wora_beta[adapter_name] = nn.Parameter(torch.tensor(1.0), requires_grad=True)
# Explicitly ensure they remain trainable
self.lora_wora_alpha[adapter_name].requires_grad_(True)
self.lora_wora_beta[adapter_name].requires_grad_(True)

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
@@ -1452,13 +1501,16 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_wora: bool = False, **kwargs) -> Optional[LoraVariant]:
if not use_wora and not use_dora:
return None

from .variants import DoraConv2dVariant
from .variants import DoraConv2dVariant, WoraConv2dVariant

return DoraConv2dVariant()
if use_wora:
return WoraConv2dVariant()
else:
return DoraConv2dVariant()


class Conv1d(_ConvNd):
@@ -1469,13 +1521,16 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_wora: bool = False, **kwargs) -> Optional[LoraVariant]:
if not use_wora and not use_dora:
return None

from .variants import DoraConv1dVariant
from .variants import DoraConv1dVariant, WoraConv1dVariant

return DoraConv1dVariant()
if use_wora:
return WoraConv1dVariant()
else:
return DoraConv1dVariant()


class Conv3d(_ConvNd):
@@ -1486,13 +1541,16 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_wora: bool = False, **kwargs) -> Optional[LoraVariant]:
if not use_wora and not use_dora:
return None

from .variants import DoraConv3dVariant
from .variants import DoraConv3dVariant, WoraConv3dVariant

return DoraConv3dVariant()
if use_wora:
return WoraConv3dVariant()
else:
return DoraConv3dVariant()


class MultiheadAttention(nn.Module, LoraLayer):
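The variant classes referenced above (WoraLinearVariant, WoraEmbeddingVariant, WoraConv1d/2d/3dVariant) live in variants.py and are not shown in this diff. Based on the config help text, a rough sketch of the weight reparameterization they would implement might look as follows; the normalization axis and broadcasting are assumptions modeled on PEFT's DoRA implementation, not confirmed by this PR.

import torch

def wora_reparameterized_weight(W0, lora_A, lora_B, alpha, beta, magnitude, scaling):
    # Hypothetical helper, for illustration only.
    # W0:         (out_features, in_features) frozen base weight
    # lora_A:     (r, in_features), lora_B: (out_features, r)
    # alpha/beta: the new learnable scalars (lora_wora_alpha / lora_wora_beta)
    # magnitude:  (out_features,) DoRA-style magnitude vector
    combined = alpha * W0 + beta * (lora_B @ lora_A) * scaling
    # per-output-channel L2 norm, analogous to DoRA's weight normalization
    norm = combined.norm(p=2, dim=1, keepdim=True)
    return magnitude.unsqueeze(1) * combined / norm

With alpha = beta = 1.0 (their initialization in update_layer above), this sketch reduces to the usual DoRA decomposition, which is consistent with the description of WoRA as learning the trade-off between the pretrained weight and the low-rank update.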
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -195,6 +195,7 @@ def _create_and_replace(
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
"use_dora": lora_config.use_dora,
"use_wora": lora_config.use_wora,
"use_alora": lora_config.alora_invocation_tokens is not None,
"use_qalora": lora_config.use_qalora,
"qalora_group_size": lora_config.qalora_group_size,
@@ -234,6 +235,7 @@ def _create_and_replace(
init_lora_weights=lora_config.init_lora_weights,
use_rslora=lora_config.use_rslora,
use_dora=lora_config.use_dora,
use_wora=lora_config.use_wora,
lora_bias=lora_config.lora_bias,
arrow_config=lora_config.arrow_config,
inference_mode=lora_config.inference_mode,
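Finally, a small hedged sanity check, building on the usage sketch above: because "lora_wora_alpha" and "lora_wora_beta" were added to adapter_layer_names, the new scalars should show up as trainable adapter parameters. The exact parameter names are assumed from the ParameterDict attributes added in layer.py.

# Each adapter should contribute one alpha and one beta scalar per adapted layer.
for name, param in peft_model.named_parameters():
    if "lora_wora_alpha" in name or "lora_wora_beta" in name:
        print(name, tuple(param.shape), "requires_grad:", param.requires_grad)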