Skip to content

Commit 0c1039b

Browse files
authored
Fix output_embeddings get/ set for empty active head (#754)
Fixes #742.
1 parent 470af8a commit 0c1039b

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/adapters/heads/model_mixin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def _init_head_modules(self):
9494
def get_output_embeddings(self) -> Union[nn.Module, List[nn.Module]]:
9595
# Only gets the output embeddings for the currently active head
9696
embeddings = []
97+
if not self._active_heads:
98+
return None
9799
for head_name in self._active_heads:
98100
if head_name in self.heads:
99101
head = self.heads[head_name]
@@ -109,6 +111,8 @@ def get_output_embeddings(self) -> Union[nn.Module, List[nn.Module]]:
109111

110112
def set_output_embeddings(self, new_embeddings: Union[nn.Module, List[nn.Module]]):
111113
# Only sets the output embeddings for the currently active head
114+
if not self._active_heads:
115+
return
112116
if not isinstance(new_embeddings, list):
113117
new_embeddings = [new_embeddings] * len(self._active_heads)
114118
for head_name, emb in zip(self._active_heads, new_embeddings):

tests/test_adapter_heads.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ def test_delete_head(self):
258258
self.assertFalse(name in model.config.prediction_heads)
259259
self.assertNotEqual(name, model.active_head)
260260

261+
# add head again
262+
self.add_head(model, name)
263+
self.assertTrue(name in model.heads)
264+
self.assertTrue(name in model.config.prediction_heads)
265+
self.assertEqual(name, model.active_head)
266+
261267
def test_adapter_with_head(self):
262268
model1, model2 = create_twin_models(self.model_class, self.config)
263269

0 commit comments

Comments
 (0)