diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py
index f57ee970..2e363c4e 100644
--- a/lit_llama/adapter.py
+++ b/lit_llama/adapter.py
@@ -105,7 +105,7 @@ def forward(
         # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
         # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
         # weights for q, k, v) and then split the result along `embedding size` dimension
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+        q, k, v = self.c_attn(x).split(C, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
 
         # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
         # embedding size (C) dimension into num_heads (nh) and head_size (hs)
@@ -147,19 +147,19 @@ def forward(
         # "Adapters are applied to the topmost layers to better tune the language
         # representations with higher-level semantics".
         if self.block_idx >= self.adapter_start_layer:
+            aT = self.adapter_prompt_length
             if adapter_kv_cache is not None:
                 ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
             else:
-                prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
-                aT = prefix.size(1)
-                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+                prefix = self.adapter_wte.weight.reshape(1, aT, C)
+                _, ak, av = self.c_attn(prefix).split(C, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
                 ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 adapter_kv_cache = (ak, av)
 
             # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
             # obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
-            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)  # (T, aT)
+            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
             # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
             ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
             y = y + self.gating_factor * ay
@@ -223,14 +223,14 @@ class LLaMA(llama.LLaMA):
     def __init__(self, config: LLaMAConfig) -> None:
         nn.Module.__init__(self)
-        assert config.vocab_size is not None
+        assert config.padded_vocab_size is not None
         assert config.block_size is not None
         self.config = config
 
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                 ln_f=RMSNorm(config.n_embd),
             )
         )
@@ -297,7 +297,7 @@ def forward(
 
         x = self.transformer.ln_f(x)  # (B, T, n_embd)
 
-        logits = self.lm_head(x)  # (B, T, vocab_size)
+        logits = self.lm_head(x)  # (B, T, padded_vocab_size)
 
         return logits
 
diff --git a/tests/test_model.py b/tests/test_model.py
index 3abc4843..1dfd7983 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -184,7 +184,7 @@ def test_adapter_parity(orig_llama_adapter):
         dim=n_embd,
         n_layers=n_layer,
         n_heads=n_head,
-        vocab_size=vocab_size,
+        vocab_size=llama_config.padded_vocab_size,
         norm_eps=1e-5,
         max_seq_len=block_size,
         adapter_len=adapter_prompt_length,
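Note on the adapter.py changes: since C is the local embedding size unpacked from the input (the "(B, T, 3 * C) --> 3 * (B, T, C)" comments refer to it), splitting by C is equivalent to the old split by self.n_embd, and the adapter prompt length aT is now read directly from self.adapter_prompt_length instead of being recovered from the prefix tensor; behaviour is unchanged. The following is a minimal, standalone sketch of the adapter cross-attention path touched by this diff, not the repository module: sizes (B, T, C, n_head, aT) are made up, the causal self-attention mask is replaced by is_causal=True, and the gating factor is a plain zero tensor rather than a learnable parameter. It only illustrates the shapes and the non-causal all-True adapter mask used above.

import torch
import torch.nn as nn
import torch.nn.functional as F

# illustrative sizes only (assumptions for this sketch)
B, T, C, n_head, aT = 2, 8, 32, 4, 4  # batch, seq len, embedding size, heads, adapter prompt length
head_size = C // n_head

c_attn = nn.Linear(C, 3 * C, bias=False)  # fused q/k/v projection, shared with the adapter prefix
adapter_wte = nn.Embedding(aT, C)         # learnable adapter prompt
gating_factor = torch.zeros(1)            # zero-initialized gate, so the adapter starts as a no-op

x = torch.randn(B, T, C)

# regular self-attention path: split the fused projection along the embedding dimension
q, k, v = c_attn(x).split(C, dim=2)                     # (B, T, 3 * C) --> 3 * (B, T, C)
q = q.view(B, T, n_head, head_size).transpose(1, 2)     # (B, nh, T, hs)
k = k.view(B, T, n_head, head_size).transpose(1, 2)
v = v.view(B, T, n_head, head_size).transpose(1, 2)
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)  # (B, nh, T, hs)

# adapter path: project the learned prefix with the same weights and discard its query
prefix = adapter_wte.weight.reshape(1, aT, C)
_, ak, av = c_attn(prefix).split(C, dim=2)              # (1, aT, 3 * C) --> 3 * (1, aT, C)
ak = ak.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
av = av.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)

# every input token may attend to every adapter position: all-True, non-causal mask
amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)  # (T, aT)
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
y = y + gating_factor * ay

print(y.shape)  # torch.Size([2, 4, 8, 8])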