
Commit bcace97

Fix LoRA parallel composition (#752)
Currently, many model implementations don't correctly handle the batch-size replication caused by parallel composition for attention-matrix LoRAs. This PR loosens the batch-size reshaping in the affected attention modules to enable it, and adds a test case for parallel LoRA composition. Resolves #744.
1 parent 0c1039b commit bcace97
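
The idea behind the fix, as a minimal sketch (illustrative shapes only, not library code): a Parallel composition replicates the batch dimension inside the model, so attention reshapes that hard-code the original bsz break, while reading the batch size off the tensor itself keeps working.

# Minimal sketch of the failure mode (illustrative shapes, not library code)
import torch

bsz, seq_len, num_heads, head_dim = 2, 4, 8, 16
hidden = torch.randn(bsz, seq_len, num_heads * head_dim)

# Parallel composition with n adapter branches replicates the batch dimension:
n_parallel = 2
hidden = hidden.repeat(n_parallel, 1, 1)  # (n_parallel * bsz, seq_len, hidden_dim)

# Old reshape, pinned to the original bsz -> RuntimeError (element count mismatch):
# hidden.view(bsz, seq_len, num_heads, head_dim)

# Loosened reshape, taking the batch size from the tensor itself:
shaped = hidden.view(hidden.shape[0], seq_len, num_heads, head_dim).transpose(1, 2)
print(shaped.shape)  # torch.Size([4, 8, 4, 16]) = (n_parallel * bsz, num_heads, seq_len, head_dim)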

File tree: 12 files changed, +99 -38 lines

src/adapters/models/bart/modeling_bart.py

Lines changed: 23 additions & 3 deletions
@@ -19,7 +19,13 @@
 import torch.utils.checkpoint
 from torch import nn
 
-from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer
+from transformers.models.bart.modeling_bart import (
+    BartAttention,
+    BartDecoderLayer,
+    BartEncoderLayer,
+    BartFlashAttention2,
+    BartSdpaAttention,
+)
 from transformers.utils import logging
 
 from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
@@ -32,6 +38,10 @@
 class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -164,7 +174,12 @@ def forward(
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention):
+class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartFlashAttention2):
+
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -275,7 +290,12 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
-class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
+class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartSdpaAttention):
+
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,

src/adapters/models/distilbert/modeling_distilbert.py

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ def forward(
 
         def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
-            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
+            # keep first dim due to parallel composition
+            return x.view(x.shape[0], -1, self.n_heads, dim_per_head).transpose(1, 2)
 
         def unshape(x: torch.Tensor) -> torch.Tensor:
             """group heads"""

src/adapters/models/llama/modeling_llama.py

Lines changed: 13 additions & 9 deletions
@@ -85,9 +85,10 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         # >>> START AH Changes <<<
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
@@ -188,9 +189,11 @@ def forward(
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         # >>> START AH Changes <<<
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
@@ -320,9 +323,10 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         # >>> START AH Changes <<<
         query_states, key_states, value_states = match_attn_matrices_for_parallel(

src/adapters/models/mbart/modeling_mbart.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,10 @@
 class MBartAttentionWithAdapters(BartAttentionAdaptersMixin, MBartAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,

src/adapters/models/mistral/modeling_mistral.py

Lines changed: 18 additions & 15 deletions
@@ -68,15 +68,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -153,15 +154,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         kv_seq_len = key_states.shape[-2]
@@ -310,15 +312,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         cos, sin = self.rotary_emb(value_states, position_ids)
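
For context, a rough sketch of the batch-dimension bookkeeping the extended adjust_tensors_for_parallel call performs here (an assumption about its behavior, not the actual implementation): once the query states carry the replicated batch, per-batch tensors such as attention_mask and position_ids must also be repeated along dim 0 before the rotary embedding and attention steps.

# Rough sketch only; replicate_for_parallel is a hypothetical stand-in for
# adapters' adjust_tensors_for_parallel and assumes simple dim-0 repetition.
import torch

def replicate_for_parallel(reference, *tensors):
    out = []
    for t in tensors:
        if t is not None and t.shape[0] < reference.shape[0]:
            # repeat along the batch dim until it matches the reference tensor
            t = t.repeat(reference.shape[0] // t.shape[0], *([1] * (t.dim() - 1)))
        out.append(t)
    return tuple(out)

query_states = torch.randn(4, 8, 6, 16)                   # (n_parallel * bsz, heads, seq, head_dim)
attention_mask = torch.ones(2, 1, 6, 6)                   # still (bsz, ...)
position_ids = torch.arange(6).unsqueeze(0).repeat(2, 1)  # still (bsz, seq)

attention_mask, position_ids = replicate_for_parallel(query_states, attention_mask, position_ids)
print(attention_mask.shape, position_ids.shape)  # torch.Size([4, 1, 6, 6]) torch.Size([4, 6])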

src/adapters/models/mt5/modeling_mt5.py

Lines changed: 2 additions & 1 deletion
@@ -84,7 +84,8 @@ def forward(
 
         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            # keep first dim due to parallel composition
+            return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""

src/adapters/models/plbart/modeling_plbart.py

Lines changed: 14 additions & 0 deletions
@@ -36,6 +36,10 @@
 class PLBartAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -169,6 +173,11 @@ def forward(
 
 
 class PLBartFlashAttention2WithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):
+
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -280,6 +289,11 @@ def forward(
 
 
 class PLBartSdpaAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):
+
+    # Loosen constraint on batch_size to allow parallel adapter composition
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
     def forward(
         self,
         hidden_states: torch.Tensor,

src/adapters/models/t5/modeling_t5.py

Lines changed: 2 additions & 1 deletion
@@ -84,7 +84,8 @@ def forward(
 
         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            # keep first dim due to parallel composition
+            return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""

tests/composition/test_parallel.py

Lines changed: 15 additions & 5 deletions
@@ -3,7 +3,14 @@
 
 import torch
 
-from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig, SeqBnConfig, T5AdapterModel
+from adapters import (
+    ADAPTER_MODEL_MAPPING,
+    AutoAdapterModel,
+    LoRAConfig,
+    PrefixTuningConfig,
+    SeqBnConfig,
+    T5AdapterModel,
+)
 from adapters.composition import BatchSplit, Parallel
 from adapters.models.bert_generation.adapter_model import BertGenerationAdapterModel
 from transformers import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, Trainer, TrainingArguments
@@ -276,13 +283,16 @@ def run_parallel_training_equivalent_to_single(self, adapter_config):
             self.assertTrue(torch.allclose(v, state_dict[k.replace(b1, b2)], atol=1e-5))
 
     def test_parallel_training_bottleneck(self):
-        self.run_parallel_training_test(SeqBnConfig(), "adapters.{}")
+        self.run_parallel_training_test(SeqBnConfig(reduction_factor=48), "adapters.{}")
+
+    def test_parallel_training_lora(self):
+        self.run_parallel_training_test(LoRAConfig(r=1), "loras.{}")
 
     def test_parallel_training_prefix_tuning(self):
         self.run_parallel_training_test(PrefixTuningConfig(), "prefix_tunings.{}")
 
     def test_parallel_training_equivalent_to_single_bottleneck(self):
-        self.run_parallel_training_equivalent_to_single(SeqBnConfig())
+        self.run_parallel_training_equivalent_to_single(SeqBnConfig(reduction_factor=48))
 
     def test_parallel_training_equivalent_to_single_prefix_tuning(self):
         self.run_parallel_training_equivalent_to_single(PrefixTuningConfig())
@@ -291,8 +301,8 @@ def test_parallel_training_single_forward_pass(self):
         model = AutoAdapterModel.from_config(self.config())
         model.eval()
 
-        a1, a2 = self.create_twin_adapters(model, "a", SeqBnConfig())
-        b1, b2 = self.create_twin_adapters(model, "b", SeqBnConfig())
+        a1, a2 = self.create_twin_adapters(model, "a", SeqBnConfig(reduction_factor=48))
+        b1, b2 = self.create_twin_adapters(model, "b", SeqBnConfig(reduction_factor=48))
 
         state_dict = model.state_dict()
         for k, v in state_dict.items():
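
For reference, a usage sketch of the scenario the new test covers (assumes the public adapters API; the model checkpoint and adapter names are illustrative): two LoRA adapters composed in parallel over a single input batch.

# Usage sketch; assumes the public adapters API, names are illustrative.
import adapters
from adapters import LoRAConfig
from adapters.composition import Parallel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
adapters.init(model)  # add adapter support to the vanilla HF model

model.add_adapter("a", config=LoRAConfig(r=8))
model.add_adapter("b", config=LoRAConfig(r=8))
model.set_active_adapters(Parallel("a", "b"))

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Parallel composition replicates the batch.", return_tensors="pt")
outputs = model(**inputs)
# With Parallel active, the output batch dimension typically contains one block per adapter branch.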

tests/test_deberta.py

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ class DebertaAdapterTest(
     DebertaAdapterTestBase,
     unittest.TestCase,
 ):
-    pass
+    def test_parallel_training_lora(self):
+        self.skipTest("Not supported for DeBERTa")
 
 
 @require_torch
