18 changes: 9 additions & 9 deletions lit_llama/adapter.py
@@ -105,7 +105,7 @@ def forward(
         # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
         # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
         # weights for q, k, v) and then split the result along `embedding size` dimension
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+        q, k, v = self.c_attn(x).split(C, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)

         # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
         # embedding size (C) dimension into num_heads (nh) and head_size (hs)
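
Note on the hunk above: `C` is unpacked from the input shape as the embedding size, so it names the same value as `self.n_embd`; the change only makes the split use the local name. A minimal standalone sketch of the fused q/k/v projection and split, with made-up sizes (not lit-llama's own module):

    import torch
    import torch.nn as nn

    # Illustrative sizes only: batch, sequence length, embedding size, number of heads.
    B, T, C, n_head = 2, 8, 32, 4
    c_attn = nn.Linear(C, 3 * C, bias=False)   # stand-in for the fused q/k/v projection

    x = torch.randn(B, T, C)
    q, k, v = c_attn(x).split(C, dim=2)        # (B, T, 3 * C) --> 3 * (B, T, C)
    assert q.shape == k.shape == v.shape == (B, T, C)

    # move head_size (hs) next to the batch dimension: (B, T, C) -> (B, nh, T, hs)
    head_size = C // n_head
    q = q.view(B, T, n_head, head_size).transpose(1, 2)
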
@@ -147,19 +147,19 @@ def forward(
         # "Adapters are applied to the topmost layers to better tune the language
         # representations with higher-level semantics".
         if self.block_idx >= self.adapter_start_layer:
+            aT = self.adapter_prompt_length
             if adapter_kv_cache is not None:
                 ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
             else:
-                prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
-                aT = prefix.size(1)
-                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+                prefix = self.adapter_wte.weight.reshape(1, aT, C)
+                _, ak, av = self.c_attn(prefix).split(C, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
                 ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 adapter_kv_cache = (ak, av)

             # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
             # obtained from self-attention step. This is mathematically equivalent to concatenation of prefix and input as per paper.
-            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)  # (T, aT)
+            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
             # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
             ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
             y = y + self.gating_factor * ay
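
For readers following the shapes in this hunk, here is a standalone sketch of the adapter branch with illustrative sizes; the projection weight and gating factor below are stand-ins, not the model's trained parameters:

    import torch
    import torch.nn.functional as F

    B, T, C, n_head, aT = 2, 8, 32, 4, 10      # aT = adapter prompt length
    head_size = C // n_head

    q = torch.randn(B, n_head, T, head_size)   # query from the regular self-attention path
    y = torch.randn(B, n_head, T, head_size)   # output of the causal self-attention step
    adapter_wte = torch.randn(aT, C)           # stand-in for the learned adapter prompt embeddings
    c_attn_weight = torch.randn(3 * C, C)      # stand-in for the shared q/k/v projection weight

    prefix = adapter_wte.reshape(1, aT, C)
    _, ak, av = (prefix @ c_attn_weight.T).split(C, dim=2)                     # 3 * (1, aT, C)
    ak = ak.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
    av = av.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)

    # every input position may attend to every adapter slot, hence a full (T, aT) mask
    amask = torch.ones(T, aT, dtype=torch.bool)
    ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
    gating_factor = 0.1                        # a plain scalar here; a learned parameter in the model
    y = y + gating_factor * ay
    assert y.shape == (B, n_head, T, head_size)
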
@@ -223,14 +223,14 @@ class LLaMA(llama.LLaMA):

     def __init__(self, config: LLaMAConfig) -> None:
         nn.Module.__init__(self)
-        assert config.vocab_size is not None
+        assert config.padded_vocab_size is not None
         assert config.block_size is not None
         self.config = config

-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                 ln_f=RMSNorm(config.n_embd),
             )
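
Context for this hunk: the embedding table and the output head must both be sized with the padded vocabulary, or checkpoints built around `padded_vocab_size` would not load. A hedged sketch of how such a padded size can be derived, assuming rounding up to a multiple of 64 (the authoritative rule lives in `LLaMAConfig`):

    # Assumption for illustration: pad the tokenizer's vocab size up to a multiple of 64.
    def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
        remainder = vocab_size % multiple
        return vocab_size if remainder == 0 else vocab_size + multiple - remainder

    assert pad_vocab_size(32000) == 32000   # LLaMA's 32000-token vocab is already a multiple of 64
    assert pad_vocab_size(50257) == 50304   # a GPT-2-sized vocab would be padded up
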
@@ -297,7 +297,7 @@ def forward(

         x = self.transformer.ln_f(x)  # (B, T, n_embd)

-        logits = self.lm_head(x)  # (B, T, vocab_size)
+        logits = self.lm_head(x)  # (B, T, padded_vocab_size)

         return logits
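
Since the head now emits `padded_vocab_size` logits, any positions beyond the tokenizer's real vocabulary correspond to ids the tokenizer never produces. A hypothetical caller-side step (not part of this diff) to keep only real token logits before sampling:

    import torch

    vocab_size, padded_vocab_size = 100, 128       # made-up sizes; equal for LLaMA's 32000-token vocab
    logits = torch.randn(2, 8, padded_vocab_size)  # (B, T, padded_vocab_size), as returned above
    real_logits = logits[..., :vocab_size]         # (B, T, vocab_size)
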

2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -184,7 +184,7 @@ def test_adapter_parity(orig_llama_adapter):
         dim=n_embd,
         n_layers=n_layer,
         n_heads=n_head,
-        vocab_size=vocab_size,
+        vocab_size=llama_config.padded_vocab_size,
         norm_eps=1e-5,
         max_seq_len=block_size,
         adapter_len=adapter_prompt_length,
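
The parity test copies weights between the lit-llama adapter model and the reference implementation, which requires every tensor shape to match; passing the padded vocab size to the reference `ModelArgs` keeps the embedding and head shapes aligned. A minimal illustration of why, with made-up shapes (not the actual test):

    import torch.nn as nn

    n_embd, vocab_size, padded_vocab_size = 32, 100, 128   # illustrative sizes only

    ours = nn.Embedding(padded_vocab_size, n_embd)
    reference_padded = nn.Embedding(padded_vocab_size, n_embd)
    reference_unpadded = nn.Embedding(vocab_size, n_embd)

    reference_padded.load_state_dict(ours.state_dict())        # shapes match -> loads fine
    try:
        reference_unpadded.load_state_dict(ours.state_dict())  # shape mismatch -> RuntimeError
    except RuntimeError as exc:
        print("expected size mismatch:", exc)
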