
Commit 7550099

Allow specifying adapter dtype in AdapterConfig (#767)
Aims to fix #766. Backwards compatible, since `dtype` defaults to `None` if not set in `AdapterConfig`.
1 parent e591965 commit 7550099
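
For context, a minimal usage sketch of the new option (not part of the commit; it assumes the `adapters` library on top of a Hugging Face Transformers model, and the model and adapter names are placeholders). The `dtype` string names a `torch` dtype and defaults to `None`, so existing configs are unaffected:

# Hypothetical usage sketch -- model and adapter names are placeholders.
import adapters
import torch
from adapters import LoRAConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
adapters.init(model)

# New: `dtype` takes the name of a torch dtype as a string; it defaults to None,
# in which case the LoRA tensors keep torch's default dtype as before.
config = LoRAConfig(r=8, alpha=16, dtype="bfloat16")
model.add_adapter("my_lora", config=config)
model.train_adapter("my_lora")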

3 files changed: +17, -5 lines changed

src/adapters/configuration/adapter_config.py

Lines changed: 8 additions & 0 deletions
@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
             Place a trainable gating module besides the added parameter module to control module activation. This is
             e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
             `merge_adapter()`.
+        dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
     """

     architecture: Optional[str] = "lora"
@@ -499,6 +500,7 @@ class LoRAConfig(AdapterConfig):
     composition_mode: str = "add"
     init_weights: str = "lora"
     use_gating: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -521,6 +523,7 @@ class IA3Config(LoRAConfig):
     composition_mode: str = "scale"
     init_weights: str = "ia3"
     use_gating: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -540,6 +543,7 @@ class ReftConfig(AdapterConfig):
         subtract_projection (bool): If True, subtract the projection of the input.
         dropout (float): The dropout rate used in the intervention layer.
         non_linearity (str): The activation function used in the intervention layer.
+        dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
     """

     layers: Union[Literal["all"], List[int]]
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
+    dtype: Optional[str] = None

     architecture: str = "reft"

@@ -569,6 +574,7 @@ class LoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = True
     tied_weights: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -583,6 +589,7 @@ class NoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = False
     tied_weights: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -598,6 +605,7 @@ class DiReftConfig(ReftConfig):
     orthogonality: bool = False
     tied_weights: bool = False
     subtract_projection = False
+    dtype: Optional[str] = None


 class ConfigUnion(AdapterConfig):

src/adapters/methods/lora.py

Lines changed: 3 additions & 2 deletions
@@ -51,9 +51,10 @@ def __init__(
         else:
             self.lora_dropout = lambda x: x

+        dtype = getattr(torch, config.dtype) if config.dtype else None
         # Actual trainable parameters
-        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
-        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
+        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
         self.scaling = self.lora_alpha / self.r

         # For compatibility with (IA)^3, allow all init_weights types here.
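
As an aside, the string-to-dtype resolution added above is plain `getattr` on the `torch` module; a standalone sketch of the same pattern (illustrative names only, not code from the commit):

import torch

def resolve_dtype(name):
    # "bfloat16" -> torch.bfloat16; None (the default) means "don't override".
    return getattr(torch, name) if name else None

lora_A = torch.nn.Parameter(torch.zeros((8, 768), dtype=resolve_dtype("bfloat16")))
print(lora_A.dtype)  # torch.bfloat16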

src/adapters/methods/reft.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -18,12 +18,13 @@ def __init__(
         subtract_projection: bool = True,
         non_linearity: str = None,
         dropout: float = 0.0,
+        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.orthogonal = orthogonal
-        self.learned_source = nn.Linear(in_dim, r_dim, bias=True)
+        self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

-        projection = nn.Linear(in_dim, r_dim, bias=False)
+        projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
         if orthogonal:
             self.projection = nn.utils.parametrizations.orthogonal(projection)
         else:
@@ -50,6 +51,7 @@ def __init__(self, in_features: int, config: ReftConfig):
         self.suffix_positions = config.suffix_positions
         self.tied_weights = config.tied_weights
         n_units = 1 if config.tied_weights else 2
+        dtype = getattr(torch, config.dtype) if config.dtype else None
         self.units = nn.ModuleList(
             [
                 ReftUnit(
@@ -59,6 +61,7 @@ def __init__(self, in_features: int, config: ReftConfig):
                     config.subtract_projection,
                     config.non_linearity,
                     config.dropout,
+                    dtype,
                 )
                 for _ in range(n_units)
             ]
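
On the ReFT side, the resolved dtype is forwarded to `nn.Linear` as a factory kwarg, so the intervention weights are created directly in the requested precision. A minimal sketch of that behaviour (dimensions are arbitrary, not taken from the commit):

import torch
import torch.nn as nn

# nn.Linear accepts a dtype factory kwarg, so the projection and learned-source
# weights come out in the requested precision (dimensions here are arbitrary).
unit_projection = nn.Linear(768, 16, bias=False, dtype=torch.float16)
print(unit_projection.weight.dtype)  # torch.float16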
