
Commit 72ecc2d

Fix in LoRA code
1 parent bf68369 commit 72ecc2d

2 files changed: 57 additions, 29 deletions


litgpt/lora.py

Lines changed: 17 additions & 21 deletions
@@ -178,7 +178,6 @@ def __init__(
         self,
         # ↓ this part is for pretrained weights
         in_features: int,
-        out_features: int,
         # ↓ the remaining part is for LoRA
         head_size: int,
         n_head: int,
@@ -199,7 +198,6 @@ def __init__(
 
         Args:
             in_features: number of input features of the pretrained weights
-            out_features: number of output features of the pretrained weights
             head_size: size of a single attention head
             n_head: number of attention heads
             n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
                 and `value` but keep `key` without weight updates we should pass `[True, False, True]`
         """
         super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+        out_features = head_size * (n_head + 2 * n_query_groups)
         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
         self.head_size = head_size
         self.n_head = n_head
@@ -229,18 +228,19 @@ def __init__(
         # ⚬ out_features: 384 (3 * embedding_size)
         # ⚬ r: 2
         # ⚬ enable_lora: [True, False, True]
+        self._all_qkv_shapes = (
+            # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
+            # might not be equal to `head_size * n_head`, thus we use it directly here
+            head_size * n_head,
+            head_size * n_query_groups,
+            head_size * n_query_groups,
+        )
         if r > 0 and any(enable_lora):
             self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)
-            enable_q, enable_k, enable_v = enable_lora
             # qkv_shapes will be used to split a tensor with weights correctly
-            qkv_shapes = (
-                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
-                # might not be equal to `head_size * n_head`, thus we use it directly here
-                head_size * n_head * enable_q,
-                head_size * n_query_groups * enable_k,
-                head_size * n_query_groups * enable_v,
-            )
-            self.qkv_shapes = [s for s in qkv_shapes if s]
+            self.qkv_shapes = [
+                s for s, e in zip(self._all_qkv_shapes, enable_lora) if e
+            ]
             self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))
             # Notes about shapes above
             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
@@ -266,15 +266,13 @@ def lora_ind(self) -> torch.Tensor:
         """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
         # Indices are needed to properly pad weight updates with zeros.
         if not hasattr(self, "_lora_ind"):
-            enable_q, enable_k, enable_v = self.enable_lora
-            kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
+            off = 0
             lora_ind = []
-            if enable_q:
-                lora_ind.extend(range(0, self.linear.in_features))
-            if enable_k:
-                lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
-            if enable_v:
-                lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
+            for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
+                if enable:
+                    lora_ind.extend(range(off, off + size))
+                off += size
+            assert len(lora_ind) == sum(self.qkv_shapes)  # Sanity check
             self.register_buffer(
                 "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
             )
@@ -527,10 +525,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
     def __init__(self, config: Config, block_idx: int) -> None:
         super().__init__(config, block_idx)
         # key, query, value projections for all heads, but in a batch
-        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = LoRAQKVLinear(
             in_features=config.n_embd,
-            out_features=shape,
             r=config.lora_r,
             lora_alpha=config.lora_alpha,
             lora_dropout=config.lora_dropout,
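
In short, `LoRAQKVLinear` now derives `out_features` from the attention geometry (`head_size * (n_head + 2 * n_query_groups)`) and builds its zero-padding indices from the same per-projection sizes, instead of deriving them from `in_features`. The standalone sketch below is not part of the commit; it just walks through that arithmetic with illustrative Qwen3-4B-like dimensions (the concrete numbers are assumptions) to show why the old index math breaks when `n_embd != head_size * n_head`.

# Standalone sketch (not litgpt code): why per-projection sizes are needed.
# Qwen3-4B-like numbers, used here purely as an illustrative assumption.
n_embd = 2560          # in_features of the fused QKV projection
n_head = 32
n_query_groups = 8
head_size = 128        # head_size * n_head = 4096, which is NOT equal to n_embd

enable_lora = (True, False, True)  # LoRA on query and value, not key

# New scheme: sizes come straight from the attention geometry.
all_qkv_shapes = (
    head_size * n_head,          # query rows: 4096
    head_size * n_query_groups,  # key rows:   1024
    head_size * n_query_groups,  # value rows: 1024
)
out_features = sum(all_qkv_shapes)  # 6144 == head_size * (n_head + 2 * n_query_groups)

lora_ind, off = [], 0
for enable, size in zip(enable_lora, all_qkv_shapes):
    if enable:
        lora_ind.extend(range(off, off + size))
    off += size
assert len(lora_ind) == head_size * (n_head + n_query_groups)  # 5120 rows receive updates

# Old scheme derived the key/value size from in_features:
#   kv_embd_size = n_embd // (n_head // n_query_groups)  ->  2560 // 4 = 640
# which only equals head_size * n_query_groups (1024) when n_embd == head_size * n_head,
# so the padding indices were wrong for configs like this one.
print(out_features, len(lora_ind))  # 6144 5120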

tests/test_lora.py

Lines changed: 40 additions & 8 deletions
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
     original_linear = torch.nn.Linear
     # Our bnb does this sort of monkey patching
     torch.nn.Linear = MyLinear
-    layer = LoRAQKVLinear(1, 1, 1, 1, 1)
+    layer = LoRAQKVLinear(1, 1, 1, 1)
     assert isinstance(layer.linear, original_linear)
     torch.nn.Linear = original_linear
 
@@ -355,7 +355,7 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
 def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):
     C = 12
     layer = LoRAQKVLinear(
-        C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
+        C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
     )
     x = torch.randn((1, 1, C))
     a = F.linear(x, layer.lora_A).transpose(-2, -1)  # after_A
@@ -389,7 +389,7 @@ def test_lora_linear_weights_merged_status(rank, expected_merged):
 )
 def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):
     C = 10
-    layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
+    layer = LoRAQKVLinear(C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
     assert not layer.merged
     layer.merge()
     assert layer.merged == expected_merged
@@ -909,14 +909,11 @@ def test_zero_pad_cpu_and_mocked_mps():
     n_head = 12
     n_query_groups = 3
     in_features = 128
-    kv_embed_dim = in_features // (n_head // n_query_groups)
-    out_features = in_features + 2 * kv_embed_dim
-    enable_lora = [True, False, True]
+    enable_lora = (True, False, True)
     r = 4
 
     model = LoRAQKVLinear(
         in_features=in_features,
-        out_features=out_features,
         head_size=head_size,
         n_head=n_head,
         n_query_groups=n_query_groups,
@@ -926,7 +923,7 @@ def test_zero_pad_cpu_and_mocked_mps():
 
     batch_size = 64
     seq_len = 64
-    embed_dim = 160
+    embed_dim = head_size * (n_head + n_query_groups)
     x = torch.randn(batch_size, seq_len, embed_dim)
 
     result_cpu = model.zero_pad(x)
@@ -1167,3 +1164,38 @@ def test_load_from_full_model_state_dict():
 
     output_cpu_offload = model_cpu_offload(x)
     assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)
+
+
+def test_forward_qwen3_4b():
+    """
+    Tests whether LoRA forward works with the Qwen3-4B model, whose
+    transformer block has some non-standard configs:
+
+    * `n_embd != n_head * head_size`
+    * `n_head != n_query_groups`
+    * LoRA adapters added to queries and values, but not keys
+
+    """
+    device = torch.device("cpu")
+    T = 20
+    config = Config.from_name(
+        "Qwen3-4B",
+        block_size=T,
+        n_layer=2,
+        vocab_size=128,
+        padded_vocab_size=128,
+        lora_r=8,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        lora_query=True,
+        lora_key=False,
+        lora_value=True,
+    )
+    model = LoRAGPT(config)
+    x = torch.randint(
+        low=0,
+        high=config.padded_vocab_size,
+        size=(2, T),
+        device=device,
+    )
+    y = model(x)
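
With the `out_features` argument gone, callers construct the layer from the attention geometry alone. A minimal usage sketch of the updated constructor, mirroring `test_lora_qkv_linear_weights_merged_status` (the concrete values here are illustrative, not from the commit):

import torch
from litgpt.lora import LoRAQKVLinear

# out_features is computed internally: head_size * (n_head + 2 * n_query_groups) = 5 * (2 + 4) = 30
layer = LoRAQKVLinear(10, head_size=5, n_head=2, n_query_groups=2, r=2, enable_lora=(True, False, True))
x = torch.randn(1, 4, 10)   # (batch, seq_len, in_features)
y = layer(x)                # fused QKV projection plus zero-padded LoRA update
assert y.shape == (1, 4, 30)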
