From a1acde21ed410e4f430cab39e3cff7612164d561 Mon Sep 17 00:00:00 2001
From: "Andrei.Aksionov"
Date: Fri, 2 Jun 2023 17:26:04 +0300
Subject: [PATCH 1/4] CausalSelfAttention.forward: n_embd --> C

---
 lit_llama/adapter.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py
index f57ee970..ad0eb9b9 100644
--- a/lit_llama/adapter.py
+++ b/lit_llama/adapter.py
@@ -105,7 +105,7 @@ def forward(
         # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
         # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
         # weights for q, k, v) and then split the result along `embedding size` dimension
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+        q, k, v = self.c_attn(x).split(C, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
 
         # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
         # embedding size (C) dimension into num_heads (nh) and head_size (hs)
@@ -150,9 +150,9 @@ def forward(
             if adapter_kv_cache is not None:
                 ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
             else:
-                prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
-                aT = prefix.size(1)
-                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+                aT = self.adapter_prompt_length
+                prefix = self.adapter_wte.weight.reshape(1, aT, C)
+                _, ak, av = self.c_attn(prefix).split(C, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
                 ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 adapter_kv_cache = (ak, av)

From 8e7501bed28d5eb7371b3790c526f66d6519bedc Mon Sep 17 00:00:00 2001
From: "Andrei.Aksionov"
Date: Fri, 2 Jun 2023 17:30:19 +0300
Subject: [PATCH 2/4] Adapter version of LLaMA class: vocab_size -->
 padded_vocab_size

---
 lit_llama/adapter.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py
index ad0eb9b9..3c3f90d4 100644
--- a/lit_llama/adapter.py
+++ b/lit_llama/adapter.py
@@ -223,14 +223,14 @@ class LLaMA(llama.LLaMA):
     def __init__(self, config: LLaMAConfig) -> None:
         nn.Module.__init__(self)
-        assert config.vocab_size is not None
+        assert config.padded_vocab_size is not None
         assert config.block_size is not None
         self.config = config
 
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                 ln_f=RMSNorm(config.n_embd),
             )
         )
@@ -297,7 +297,7 @@ def forward(
 
         x = self.transformer.ln_f(x)  # (B, T, n_embd)
 
-        logits = self.lm_head(x)  # (B, T, vocab_size)
+        logits = self.lm_head(x)  # (B, T, padded_vocab_size)
 
         return logits

From f1cc27b2f454f3573fdf2a0f55fb79ed5dc8a98a Mon Sep 17 00:00:00 2001
From: "Andrei.Aksionov"
Date: Fri, 2 Jun 2023 21:36:16 +0300
Subject: [PATCH 3/4] Put `aT=self.adapter_prompt_length` above and use it
 instead of asking for shape of ak/av

---
 lit_llama/adapter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py
index 3c3f90d4..2e363c4e 100644
--- a/lit_llama/adapter.py
+++ b/lit_llama/adapter.py
@@ -147,10 +147,10 @@ def forward(
         # "Adapters are applied to the topmost layers to better tune the language
         # representations with higher-level semantics".
         if self.block_idx >= self.adapter_start_layer:
+            aT = self.adapter_prompt_length
             if adapter_kv_cache is not None:
                 ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
             else:
-                aT = self.adapter_prompt_length
                 prefix = self.adapter_wte.weight.reshape(1, aT, C)
                 _, ak, av = self.c_attn(prefix).split(C, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
                 ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
@@ -159,7 +159,7 @@ def forward(
 
             # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
             # obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
-            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)  # (T, aT)
+            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
             # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
             ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
             y = y + self.gating_factor * ay

From cdd05ddbaf9d348f2e614fdc32baff3db43e5a55 Mon Sep 17 00:00:00 2001
From: "Andrei.Aksionov"
Date: Fri, 2 Jun 2023 22:28:32 +0300
Subject: [PATCH 4/4] Adapter test: original model should also use
 `padded_vocab_size`

---
 tests/test_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index 3abc4843..1dfd7983 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -184,7 +184,7 @@ def test_adapter_parity(orig_llama_adapter):
         dim=n_embd,
         n_layers=n_layer,
         n_heads=n_head,
-        vocab_size=vocab_size,
+        vocab_size=llama_config.padded_vocab_size,
         norm_eps=1e-5,
         max_seq_len=block_size,
         adapter_len=adapter_prompt_length,
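
Note for reviewers: below is a minimal standalone sketch (not part of the patches) of the attention path that patches 1 and 3 rename, showing what C, aT and the adapter prefix shapes refer to. The hyperparameters and random inputs are made up for illustration; only the tensor shapes and the c_attn / adapter_wte / gating_factor roles mirror the lit_llama code.

# sketch: C == n_embd, aT == adapter_prompt_length; B, T, n_head, head_size are arbitrary
import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, n_head, head_size = 2, 8, 4, 16
C = n_head * head_size                  # embedding size (n_embd)
aT = 10                                 # adapter prompt length

c_attn = nn.Linear(C, 3 * C, bias=False)    # fused q/k/v projection
adapter_wte = nn.Embedding(aT, C)           # learnable adapter prompt
gating_factor = torch.zeros(1)              # zero-init, so the adapter starts as a no-op

x = torch.randn(B, T, C)

# fused projection, then split along the embedding dimension (this is what `split(C, dim=2)` does)
q, k, v = c_attn(x).split(C, dim=2)                           # 3 * (B, T, C)
q = q.view(B, T, n_head, head_size).transpose(1, 2)           # (B, nh, T, hs)
k = k.view(B, T, n_head, head_size).transpose(1, 2)
v = v.view(B, T, n_head, head_size).transpose(1, 2)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # (B, nh, T, hs)

# adapter prefix: project the learned prompt with the same fused weights, keep only keys/values
prefix = adapter_wte.weight.reshape(1, aT, C)
_, ak, av = c_attn(prefix).split(C, dim=2)                    # 3 * (1, aT, C)
ak = ak.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
av = av.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)

# non-causal cross-attention from the input queries to the adapter prefix, gated and summed
amask = torch.ones(T, aT, dtype=torch.bool)                   # (T, aT), all positions visible
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, is_causal=False)  # (B, nh, T, hs)
y = y + gating_factor * ay
print(y.shape)                                                # torch.Size([2, 4, 8, 16]) == (B, nh, T, hs)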
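
For patches 2 and 4, a rough sketch of the sizing behind padded_vocab_size: the config rounds the tokenizer vocabulary up to a convenient multiple, and both wte and lm_head must be built with that padded size so their weight shapes and the logits agree. The helper and numbers below are illustrative, not the exact lit_llama implementation.

def find_multiple(n: int, k: int) -> int:
    """Smallest multiple of k that is >= n (illustrative helper)."""
    return n if n % k == 0 else n + k - (n % k)

vocab_size = 32_001                                  # e.g. a tokenizer with one extra special token
padded_vocab_size = find_multiple(vocab_size, 64)    # -> 32_064
assert padded_vocab_size >= vocab_size and padded_vocab_size % 64 == 0

# Both ends of the model use the padded size:
#   wte:     Embedding(padded_vocab_size, n_embd)
#   lm_head: Linear(n_embd, padded_vocab_size)  -> logits of shape (B, T, padded_vocab_size)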