diff --git a/litgpt/lora.py b/litgpt/lora.py
index 6739b5b040..e608c00fe4 100644
--- a/litgpt/lora.py
+++ b/litgpt/lora.py
@@ -178,7 +178,6 @@ def __init__(
         self,
         # ↓ this part is for pretrained weights
         in_features: int,
-        out_features: int,
         # ↓ the remaining part is for LoRA
         head_size: int,
         n_head: int,
@@ -199,7 +198,6 @@ def __init__(
 
         Args:
             in_features: number of input features of the pretrained weights
-            out_features: number of output features of the pretrained weights
             head_size: size of a single attention head
             n_head: number of attention heads
             n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
                 and `value` but keep `key` without weight updates we should pass `[True, False, True]`
         """
         super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+        out_features = head_size * (n_head + 2 * n_query_groups)
         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
         self.head_size = head_size
         self.n_head = n_head
@@ -229,18 +228,17 @@ def __init__(
         # ⚬ out_features: 384 (3 * embedding_size)
         # ⚬ r: 2
         # ⚬ enable_lora: [True, False, True]
+        self._all_qkv_shapes = (
+            # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
+            # might not be equal to `head_size * n_head`, thus we use it directly here
+            head_size * n_head,
+            head_size * n_query_groups,
+            head_size * n_query_groups,
+        )
         if r > 0 and any(enable_lora):
             self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)
-            enable_q, enable_k, enable_v = enable_lora
             # qkv_shapes will be used to split a tensor with weights correctly
-            qkv_shapes = (
-                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
-                # might not be equal to `head_size * n_head`, thus we use it directly here
-                head_size * n_head * enable_q,
-                head_size * n_query_groups * enable_k,
-                head_size * n_query_groups * enable_v,
-            )
-            self.qkv_shapes = [s for s in qkv_shapes if s]
+            self.qkv_shapes = [s for s, e in zip(self._all_qkv_shapes, enable_lora) if e]
             self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))
             # Notes about shapes above
             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
@@ -266,15 +264,13 @@ def lora_ind(self) -> torch.Tensor:
         """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
         # Indices are needed to properly pad weight updates with zeros.
         if not hasattr(self, "_lora_ind"):
-            enable_q, enable_k, enable_v = self.enable_lora
-            kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
+            off = 0
             lora_ind = []
-            if enable_q:
-                lora_ind.extend(range(0, self.linear.in_features))
-            if enable_k:
-                lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
-            if enable_v:
-                lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
+            for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
+                if enable:
+                    lora_ind.extend(range(off, off + size))
+                off += size
+            assert len(lora_ind) == sum(self.qkv_shapes)  # Sanity check
             self.register_buffer(
                 "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
             )
@@ -527,10 +523,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
     def __init__(self, config: Config, block_idx: int) -> None:
         super().__init__(config, block_idx)
         # key, query, value projections for all heads, but in a batch
-        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = LoRAQKVLinear(
             in_features=config.n_embd,
-            out_features=shape,
             r=config.lora_r,
             lora_alpha=config.lora_alpha,
             lora_dropout=config.lora_dropout,
diff --git a/tests/test_lora.py b/tests/test_lora.py
index 1585ea4449..cd1dfe4adf 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
     original_linear = torch.nn.Linear
     # Our bnb does this sort of monkey patching
     torch.nn.Linear = MyLinear
-    layer = LoRAQKVLinear(1, 1, 1, 1, 1)
+    layer = LoRAQKVLinear(1, 1, 1, 1)
     assert isinstance(layer.linear, original_linear)
     torch.nn.Linear = original_linear
 
@@ -354,9 +354,7 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
 )
 def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):
     C = 12
-    layer = LoRAQKVLinear(
-        C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
-    )
+    layer = LoRAQKVLinear(C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
     x = torch.randn((1, 1, C))
     a = F.linear(x, layer.lora_A).transpose(-2, -1)  # after_A
     b = layer.lora_B.data.unsqueeze(-1)
@@ -389,7 +387,7 @@ def test_lora_linear_weights_merged_status(rank, expected_merged):
 )
 def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):
     C = 10
-    layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
+    layer = LoRAQKVLinear(C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
     assert not layer.merged
     layer.merge()
     assert layer.merged == expected_merged
@@ -909,14 +907,11 @@ def test_zero_pad_cpu_and_mocked_mps():
     n_head = 12
     n_query_groups = 3
     in_features = 128
-    kv_embed_dim = in_features // (n_head // n_query_groups)
-    out_features = in_features + 2 * kv_embed_dim
-    enable_lora = [True, False, True]
+    enable_lora = (True, False, True)
     r = 4
 
     model = LoRAQKVLinear(
         in_features=in_features,
-        out_features=out_features,
         head_size=head_size,
         n_head=n_head,
         n_query_groups=n_query_groups,
@@ -926,7 +921,7 @@ def test_zero_pad_cpu_and_mocked_mps():
 
     batch_size = 64
     seq_len = 64
-    embed_dim = 160
+    embed_dim = head_size * (n_head + n_query_groups)
     x = torch.randn(batch_size, seq_len, embed_dim)
 
     result_cpu = model.zero_pad(x)
@@ -1167,3 +1162,38 @@ def test_load_from_full_model_state_dict():
 
     output_cpu_offload = model_cpu_offload(x)
     assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)
+
+
+def test_forward_qwen3_4b():
+    """
+    Tests whether LoRA forward works with the Qwen3-4B model, whose
+    transformer block has some non-standard configs:
+
+    * `n_embd != n_head * head_size`
+    * `n_head != n_query_groups`
+    * LoRA adapters added to queries and values, but not keys
+
+    """
+    device = torch.device("cpu")
+    T = 20
+    config = Config.from_name(
+        "Qwen3-4B",
+        block_size=T,
+        n_layer=2,
+        vocab_size=128,
+        padded_vocab_size=128,
+        lora_r=8,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        lora_query=True,
+        lora_key=False,
+        lora_value=True,
+    )
+    model = LoRAGPT(config)
+    x = torch.randint(
+        low=0,
+        high=config.padded_vocab_size,
+        size=(2, T),
+        device=device,
+    )
+    y = model(x)
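
A minimal standalone sketch of the shape bookkeeping this patch relies on (illustrative only: `qkv_split_sizes` and `build_lora_indices` are hypothetical helpers, not litgpt APIs, and the head counts are example grouped-query-attention values rather than the exact Qwen3-4B configuration). It shows how the fused QKV width follows from `head_size`, `n_head`, and `n_query_groups`, and how the LoRA indices cover only the enabled projections:

# Standalone sketch, not litgpt code: the shape bookkeeping behind the patch.
from typing import List, Sequence, Tuple


def qkv_split_sizes(head_size: int, n_head: int, n_query_groups: int) -> Tuple[int, int, int]:
    # Q gets one slice per attention head; K and V each get one slice per query group (GQA).
    return head_size * n_head, head_size * n_query_groups, head_size * n_query_groups


def build_lora_indices(enable_lora: Sequence[bool], sizes: Sequence[int]) -> List[int]:
    # Columns of the fused QKV output that receive a LoRA update; disabled
    # projections are skipped and are later filled with zeros by zero-padding.
    indices, offset = [], 0
    for enabled, size in zip(enable_lora, sizes):
        if enabled:
            indices.extend(range(offset, offset + size))
        offset += size
    return indices


if __name__ == "__main__":
    # Example GQA setup where n_embd != head_size * n_head and n_head != n_query_groups.
    head_size, n_head, n_query_groups = 128, 32, 8
    sizes = qkv_split_sizes(head_size, n_head, n_query_groups)
    out_features = sum(sizes)  # == head_size * (n_head + 2 * n_query_groups)
    ind = build_lora_indices((True, False, True), sizes)  # LoRA on Q and V only
    assert len(ind) == sizes[0] + sizes[2]
    print(out_features, len(ind))

Deriving `out_features` from these three quantities inside `LoRAQKVLinear` is what lets the constructor drop the explicit `out_features` argument while still handling models where `n_embd != head_size * n_head`, which is the case the new `test_forward_qwen3_4b` exercises.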