Commit 0819c7f

Adapter annotated (#352)
1 parent 311b422 commit 0819c7f

lit_llama/adapter.py

Lines changed: 84 additions & 24 deletions
@@ -2,6 +2,41 @@

LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
+
+                                                                |         Prefix cross-attention
+                                                                |
+                ┌─────────────────┐                             |          ┌──────────────────┐
+                ┆        x        ┆                             |          ┆      prefix      ┆
+                └─────────────────┘                             |          └──────────────────┘
+                         |                                      |                    |
+                         ▼                                      |                    ▼
+                ┌──────────────────┐                            |         ┌─────────────────────┐
+                ┆  self-attention  ┆ ----------------------------------┐  ┆  linear projection  ┆
+                └──────────────────┘                            |      ┆  └─────────────────────┘
+                         |                                      |      ┆             |          \
+                         ▼                                      |      ▼             ▼                 ▼
+╭───╮     ┌────────────────┐ ╭───╮ ┌──────────────────────────┐ | ┌─────────┐ ┌──────────────┐ ┌────────────────┐
+┆ + ┆ ◀── ┆  gating factor ┆-┆ x ┆-┆  prefix cross-attention  ┆ | ┆  query  ┆ ┆  prefix key  ┆ ┆  prefix value  ┆
+╰───╯     └────────────────┘ ╰───╯ └──────────────────────────┘ | └─────────┘ └──────────────┘ └────────────────┘
+  |                                                             |       \            |            /
+  ▼                                                             |        ▼           ▼           ▼
+                                                                |   ┌────────────────────────────────┐
+                                                                |   ┆  scaled dot-product attention  ┆
+                                                                |   └────────────────────────────────┘
+
+
+In order to inject learnable information from the prefix into the pretrained weights we need to sum the outputs of
+self-attention and prefix cross-attention (times the gating factor). For prefix cross-attention we need `query` (from
+self-attention as a result of linear projection), `prefix key` and `prefix value` (from cross-attention as a result of
+linear projection).
+The output of prefix cross-attention is multiplied by the gating factor, which is a learnable parameter needed to
+avoid potential disruption of the pretrained weights caused by incorporating randomly initialized tensors. This factor
+is initialized with zeros to avoid noise from the adaption prompts at the early training stage.
+More about it: https://lightning.ai/pages/community/article/understanding-llama-adapters/
+
+Note about the implementation: as per the paper the adapter's prefix is concatenated with the input, while here the
+outputs of self-attention and prefix cross-attention are summed. Both variants are mathematically equivalent:
+https://github.com/ZrrSkywalker/LLaMA-Adapter/issues/47
"""
# mypy: ignore-errors
from dataclasses import dataclass
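
As a side note, the zero-init behaviour described in the new docstring can be illustrated with a small, self-contained sketch (toy dimensions and tensor names are illustrative only, not taken from the module): because the gating factor starts at zero, the prefix cross-attention branch contributes nothing at the beginning of training and the block's output is exactly the pretrained self-attention output.

import torch
import torch.nn.functional as F

B, T, aT, nh, hs = 2, 5, 3, 4, 8           # batch, sequence length, prefix length, heads, head size

q = torch.randn(B, nh, T, hs)              # queries from the input (after the linear projection)
k = torch.randn(B, nh, T, hs)              # keys from the input
v = torch.randn(B, nh, T, hs)              # values from the input
ak = torch.randn(B, nh, aT, hs)            # prefix keys (derived from the adapter prompt)
av = torch.randn(B, nh, aT, hs)            # prefix values (derived from the adapter prompt)

gating_factor = torch.zeros(1, nh, 1, 1)   # zero-init gate, one value per head

y_self = F.scaled_dot_product_attention(q, k, v, is_causal=True)       # (B, nh, T, hs)
y_prefix = F.scaled_dot_product_attention(q, ak, av, is_causal=False)  # (B, nh, T, hs)
y = y_self + gating_factor * y_prefix                                  # gated sum of the two branches

assert torch.equal(y, y_self)              # zero gate ==> pretrained behaviour is untouched at init
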
@@ -37,7 +72,8 @@ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
-            # gate for adaption
+            # a learnable gating factor (to avoid potential disruption of pretrained weights), initialized with zeros
+            # (to avoid noise from the adaption prompts at the early training stage)
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, config.n_head, 1, 1))

        self.n_head = config.n_head
@@ -57,57 +93,81 @@ def forward(
        kv_cache: Optional[KVCache] = None,
        adapter_kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
-        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
-
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-
+        # notation:
+        # - B  | batch
+        # - T  | time-step (sequence length)
+        # - C  | embedding size (n_embd) = head size * num heads
+        # - hs | head size
+        # - nh | number of heads
+
+        B, T, C = x.size()
+
+        # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with the corresponding
+        # weight matrices, do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
+        # weights for q, k, v) and then split the result along the `embedding size` dimension
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+
+        # in order to move the head_size (hs) dimension right after the batch (B) dimension, we need to first split the
+        # embedding size (C) dimension into num_heads (nh) and head_size (hs)
        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size)
        q = q.view(B, T, self.n_head, head_size)
        v = v.view(B, T, self.n_head, head_size)

-        q = apply_rope(q, rope)
-        k = apply_rope(k, rope)
+        # "Unlike standard positional embeddings rotary embeddings must be applied at every layer"
+        q = apply_rope(q, rope)  # (B, T, nh, hs)
+        k = apply_rope(k, rope)  # (B, T, nh, hs)

+        # now the `key`, `query` and `value` tensors are correctly represented: for each element in a batch (B)
+        # there is a number of heads (nh), and for each head there is a sequence of elements (T), each of them
+        # represented by a vector of size `hs`
        k = k.transpose(1, 2)  # (B, nh, T, hs)
        q = q.transpose(1, 2)  # (B, nh, T, hs)
        v = v.transpose(1, 2)  # (B, nh, T, hs)

        if kv_cache is not None:
-            cache_k, cache_v = kv_cache
+            cache_k, cache_v = kv_cache  # 2 * (B, nh, max_seq_length, hs)
            # check if reached token limit
            if input_pos[-1] >= max_seq_length:
+                # if we reached the token limit and thus there is no space to put the newly calculated `key` and `value`
+                # right next to the cached ones, we need to rotate the cache tensor along the `max_seq_length` dimension
+                # by one element to the left: this will free up space for the new `key` and `value`
                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                # shift 1 position to the left
                cache_k = torch.roll(cache_k, -1, dims=2)
                cache_v = torch.roll(cache_v, -1, dims=2)
-            k = cache_k.index_copy(2, input_pos, k)
-            v = cache_v.index_copy(2, input_pos, v)
+            k = cache_k.index_copy(2, input_pos, k)  # (B, nh, max_seq_length, hs)
+            v = cache_v.index_copy(2, input_pos, v)  # (B, nh, max_seq_length, hs)
            kv_cache = k, v

        # efficient attention using Flash Attention CUDA kernels
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)  # (B, nh, T, hs)

+        # "Adapters are applied to the topmost layers to better tune the language
+        # representations with higher-level semantics".
        if self.block_idx >= self.adapter_start_layer:
            if adapter_kv_cache is not None:
-                ak, av = adapter_kv_cache
+                ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
            else:
                prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
                aT = prefix.size(1)
-                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
-                ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
-                av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
+                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+                ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
+                av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                adapter_kv_cache = (ak, av)

-            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
-            ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
+            # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
+            # obtained from the self-attention step. This is mathematically equivalent to concatenating the prefix with the input, as per the paper.
+            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)  # (T, aT)
+            # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
+            ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
            y = y + self.gating_factor * ay

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
-        y = self.c_proj(y)
+        y = self.c_proj(y)  # (B, T, C)

        return y, kv_cache, adapter_kv_cache
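
The rolling cache update annotated above can be sanity-checked with a standalone sketch (toy sizes and variable names are illustrative only): once the cache is full, every slot is shifted one step to the left and the newest key is written into the freed-up last position.

import torch

B, nh, max_seq_length, hs = 1, 2, 4, 3
cache_k = torch.arange(B * nh * max_seq_length * hs, dtype=torch.float32).reshape(B, nh, max_seq_length, hs)
new_k = torch.full((B, nh, 1, hs), -1.0)            # key for the incoming token

input_pos = torch.tensor([max_seq_length])          # next write position, but the cache is already full
if input_pos[-1] >= max_seq_length:
    input_pos = torch.tensor([max_seq_length - 1])  # write into the last slot instead
    cache_k = torch.roll(cache_k, -1, dims=2)       # drop the oldest entry, shift the rest left

k = cache_k.index_copy(2, input_pos, new_k)         # (B, nh, max_seq_length, hs)
assert torch.equal(k[:, :, -1], new_k[:, :, 0])     # the newest key now sits in the last slot
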

@@ -202,9 +262,9 @@ def forward(
        assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        if self.rope_cache is None:
-            self.rope_cache = self.build_rope_cache(idx)
+            self.rope_cache = self.build_rope_cache(idx)  # (block_size, head_size / 2, 2)
        if self.mask_cache is None:
-            self.mask_cache = self.build_mask_cache(idx)
+            self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)

        if input_pos is not None:
            rope = self.rope_cache.index_select(0, input_pos)
@@ -215,7 +275,7 @@ def forward(
            mask = self.mask_cache[:, :, :T, :T]

        # forward the model itself
-        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)

        if input_pos is None:  # proxy for use_cache=False
            for block in self.transformer.h:
@@ -235,9 +295,9 @@ def forward(
                    x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
                )

-        x = self.transformer.ln_f(x)
+        x = self.transformer.ln_f(x)  # (B, T, n_embd)

-        logits = self.lm_head(x)  # (b, t, vocab_size)
+        logits = self.lm_head(x)  # (B, T, vocab_size)

        return logits
