Commit f0ca962

calpt and TimoImhof authored
Fix default Lora/ (IA)^3 scaling in forward (#770)
Resolves the issue described in #760.

**IMPORTANT**: this fix restores weights compatibility with adapter-transformers. Compatibility with previous `adapters` versions is kept via a compat patch.

## Details

The current implementation of LoRA / (IA)^3 in `adapters` versions < 1.1.0 does not correctly implement adapter state scaling via the LoRA `alpha` attribute, effectively ignoring `alpha` and always applying a scaling of 1.0. This PR restores the correct original behavior as found in adapter-transformers' original LoRA implementation.

As this change breaks all adapters pre-trained with `adapters` versions 0.1.0 - 1.0.1, a backward compatibility patch is introduced that automatically sets `alpha = r` for LoRAs trained with affected versions. This ensures all previous adapters continue to behave exactly as trained (i.e. give the exact same output under newer versions).

---------

Co-authored-by: TimoImhof <[email protected]>
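For context, here is a minimal, self-contained sketch of the scaling arithmetic described above. It uses the standard LoRA formulation (scaling = alpha / r); the variable names are illustrative and not taken from the library:

```python
# Standard LoRA scales its low-rank update by alpha / r.
# The bug described above effectively dropped this factor, i.e. behaved as if the scaling were 1.0.
r, alpha = 8, 16

correct_scaling = alpha / r  # 2.0 -- what adapter-transformers and adapters >= 1.1.0 apply
buggy_scaling = 1.0          # what adapters 0.1.0 - 1.0.1 effectively applied

# The compat patch sets alpha = r for adapters trained with affected versions,
# so the restored formula reproduces the behavior they were trained with:
patched_alpha = r
assert patched_alpha / r == buggy_scaling
```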
1 parent 702381e · commit f0ca962

File tree

4 files changed (+26, -2 lines)


setup.py

Lines changed: 3 additions & 1 deletion
@@ -34,6 +34,7 @@
     "isort>=5.5.4",
     "Jinja2==2.11.3",
     "nltk",
+    "packaging",
     "parameterized",
     "pillow",
     "protobuf",
@@ -136,11 +137,12 @@ def deps_list(*pkgs):
 # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
 install_requires = [
     deps["transformers"],
+    deps["packaging"],
 ]

 setup(
     name="adapters",
-    version="1.0.1",
+    version="1.1.0.dev0",
     author="The AdapterHub team and community contributors",
     author_email="[email protected]",
     description="A Unified Library for Parameter-Efficient and Modular Transfer Learning",

src/adapters/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "1.0.1"
+__version__ = "1.1.0.dev0"

 from typing import TYPE_CHECKING

src/adapters/loading.py

Lines changed: 20 additions & 0 deletions
@@ -6,6 +6,7 @@
 from typing import Callable, Mapping, Optional, Sequence, Tuple

 import torch
+from packaging.version import Version


 try:
@@ -368,6 +369,23 @@ def _rename_legacy_weights(self, k):
             k = k.replace(old, new)
         return k

+    def _fix_backward_compat(self, config):
+        # Fix error in previous versions for LoRA/ (IA)^3
+        ADAPTER_PREFIX = "adapters."
+        MIN_VERSION = Version("1.1.0")
+
+        version = config.get("version", "")
+        if version.startswith(ADAPTER_PREFIX) and Version(version[len(ADAPTER_PREFIX) :]) < MIN_VERSION:
+            if (
+                config["config"].get("architecture", None) == "lora"
+                and config["config"]["r"] != config["config"]["alpha"]
+            ):
+                logger.warning(
+                    "Loading a LoRA trained using a faulty scaling implementation of a previous library version. Editing the configuration to make sure the adapter works as trained. "
+                    "See https://github.com/adapter-hub/adapters/pull/770 for more."
+                )
+                config["config"]["alpha"] = config["config"]["r"]
+
     # This method is used to remove unnecessary invertible adapters from task adapters using the old format.
     # In the old format, task adapters e.g. using seq_bn config specify inv. adapters but don't use them.
     # As inv. adapters would be incorrectly used in the new implementation,
@@ -560,6 +578,8 @@ def load(
             # The conversion to a set and then back to a list removes all duplicates
             leave_out = list(set(leave_out + config["config"]["leave_out"]))
             config["config"]["leave_out"] = leave_out
+        # Fix issues
+        self._fix_backward_compat(config)

         adapter_name = load_as or config["name"]
         # If the adapter is not part of the model, add it

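To make the effect of `_fix_backward_compat` concrete, here is a hypothetical, standalone sketch that mirrors the patch above on a hand-written config dict (the dict contents are invented for illustration; only the keys match the code):

```python
from packaging.version import Version

# Hypothetical config as it might be loaded for a LoRA trained with an affected release.
config = {
    "version": "adapters.1.0.1",
    "name": "my_lora",
    "config": {"architecture": "lora", "r": 8, "alpha": 16},
}

ADAPTER_PREFIX = "adapters."
MIN_VERSION = Version("1.1.0")

version = config.get("version", "")
if version.startswith(ADAPTER_PREFIX) and Version(version[len(ADAPTER_PREFIX):]) < MIN_VERSION:
    if config["config"].get("architecture") == "lora" and config["config"]["r"] != config["config"]["alpha"]:
        # Same rewrite as the compat patch: force alpha == r so the (now correct)
        # alpha / r scaling evaluates to 1.0, matching what the adapter was trained with.
        config["config"]["alpha"] = config["config"]["r"]

print(config["config"])  # {'architecture': 'lora', 'r': 8, 'alpha': 8}
```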
src/adapters/methods/lora.py

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens
             hidden_states = hidden_states * gate
         else:
             gate = None
+        hidden_states = hidden_states * self.scaling

         return hidden_states, gate

@@ -171,6 +172,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens
             hidden_states = hidden_states * gate
         else:
             gate = None
+        hidden_states = hidden_states * self.scaling

         return hidden_states, gate

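The `self.scaling` factor applied above corresponds to the standard LoRA scaling (alpha / r). As a rough illustration of where that multiplication sits in a forward pass, here is a simplified stand-in module; it is not the library's actual LoRA class:

```python
import torch
import torch.nn as nn


class TinyLoRADelta(nn.Module):
    """Simplified sketch of a LoRA delta with alpha / r scaling applied in forward.

    Illustrative only; the adapters library's implementation differs in structure.
    """

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        # Standard LoRA scaling; the pre-1.1.0 bug effectively replaced this with 1.0.
        self.scaling = alpha / r

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Low-rank update (x A^T) B^T, scaled exactly once at the end of forward,
        # analogous to the fix above.
        delta = hidden_states @ self.lora_A.T @ self.lora_B.T
        return delta * self.scaling
```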