 
 LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
 https://arxiv.org/abs/2303.16199
+
+                                                                                |  Prefix cross-attention
+                                                                                |
+   ┌─────────────────┐                                                          |                ┌──────────────────┐
+   ┆        x        ┆                                                          |                ┆      prefix      ┆
+   └─────────────────┘                                                          |                └──────────────────┘
+            |                                                                   |                          |
+            ▼                                                                   |                          ▼
+  ┌──────────────────┐                                                          |               ┌─────────────────────┐
+  ┆  self-attention  ┆ ──────────────────────────────────────────────────────────────────┐      ┆  linear projection  ┆
+  └──────────────────┘                                                          |        ┆      └─────────────────────┘
+            |                                                                   |        ┆              |             \
+            ▼                                                                   |        ▼              ▼                  ▼
+          ╭───╮     ┌───────────────┐ ╭───╮ ┌────────────────────────┐          |   ┌─────────┐  ┌──────────────┐  ┌────────────────┐
+          ┆ + ┆ ◀── ┆ gating factor ┆-┆ x ┆-┆ prefix cross-attention ┆          |   ┆  query  ┆  ┆  prefix key  ┆  ┆  prefix value  ┆
+          ╰───╯     └───────────────┘ ╰───╯ └────────────────────────┘          |   └─────────┘  └──────────────┘  └────────────────┘
+            |                                                                   |            \          |             /
+            ▼                                                                   |                ▼      ▼      ▼
+                                                                                |      ┌────────────────────────────────┐
+                                                                                |      ┆  scaled dot-product attention  ┆
+                                                                                |      └────────────────────────────────┘
+
+
+In order to inject learnable information from the prefix into the pretrained model, we sum the output of
+self-attention with the output of prefix cross-attention (scaled by the gating factor). Prefix cross-attention needs a
+`query` (obtained inside self-attention as a linear projection of the input), plus a `prefix key` and `prefix value`
+(linear projections of the prefix).
+The output of prefix cross-attention is multiplied by the gating factor, a learnable parameter that protects the
+pretrained weights from being disrupted by the randomly initialized prefix tensors. The factor is initialized with
+zeros so that the adaption prompts add no noise at the early stage of training.
+More about it: https://lightning.ai/pages/community/article/understanding-llama-adapters/
+
+Note about the implementation: in the paper the adapter's prefix is concatenated with the input, while here the
+outputs of self-attention and prefix cross-attention are summed. Both variants are mathematically equivalent:
+https://github.com/ZrrSkywalker/LLaMA-Adapter/issues/47
 """
 # mypy: ignore-errors
 from dataclasses import dataclass
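
To make the equivalence note above concrete, here is a minimal, self-contained sketch with assumed toy shapes (it is not code from this repository). The "concatenated prefix" form is written out with explicit softmaxes, following the zero-init attention formulation (softmax applied to the input scores and the prefix scores separately, with the prefix part scaled by the gate); causal masking is omitted since it would apply identically to both forms.

    import torch
    import torch.nn.functional as F

    # assumed toy shapes: batch, heads, input length, prefix length, head size
    B, nh, T, aT, hs = 2, 4, 5, 3, 8
    q = torch.randn(B, nh, T, hs)
    k, v = torch.randn(B, nh, T, hs), torch.randn(B, nh, T, hs)      # keys/values from the input
    ak, av = torch.randn(B, nh, aT, hs), torch.randn(B, nh, aT, hs)  # keys/values from the prefix
    gate = torch.rand(1, nh, 1, 1)                                   # a (non-zero) gating factor

    scale = hs ** -0.5
    # paper-style form: scores over the concatenated [prefix; input] keys, softmaxed separately,
    # with the prefix part scaled by the gate
    scores_x = torch.softmax((q @ k.mT) * scale, dim=-1)
    scores_p = torch.softmax((q @ ak.mT) * scale, dim=-1) * gate
    concat_form = scores_p @ av + scores_x @ v

    # summed form used in the code below: two independent attentions, added with the gate
    sum_form = F.scaled_dot_product_attention(q, k, v) + gate * F.scaled_dot_product_attention(q, ak, av)

    print(torch.allclose(concat_form, sum_form, atol=1e-5))  # True
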
@@ -37,7 +72,8 @@ def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
         if block_idx >= config.adapter_start_layer:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
-            # gate for adaption
+            # a learnable gating factor (to avoid potential disruption of pretrained weights) initialized with zeros
+            # (to avoid noise from adaption prompts at the early training stage)
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, config.n_head, 1, 1))
 
         self.n_head = config.n_head
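
A small sketch of the zero-init effect described in the comment above, with assumed toy sizes: the gate is one learnable scalar per head with shape (1, n_head, 1, 1), so it broadcasts over batch, sequence and head size, and at initialization the adapter branch contributes nothing.

    import torch

    B, n_head, T, head_size = 2, 4, 5, 8                              # assumed toy sizes
    gating_factor = torch.nn.Parameter(torch.zeros(1, n_head, 1, 1))  # same shape as above

    y = torch.randn(B, n_head, T, head_size)   # stand-in for the pretrained self-attention output
    ay = torch.randn(B, n_head, T, head_size)  # stand-in for the (randomly initialized) adapter output

    out = y + gating_factor * ay               # (1, nh, 1, 1) broadcasts against (B, nh, T, hs)
    print(torch.equal(out, y))                 # True: at the start of training the adapter adds nothing
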
@@ -57,57 +93,81 @@ def forward(
         kv_cache: Optional[KVCache] = None,
         adapter_kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
-        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
-
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-
+        # notation:
+        # - B  | batch
+        # - T  | time-step (sequence length)
+        # - C  | embeddings size (n_embd) = head size * num heads
+        # - hs | head size
+        # - nh | number of heads
+
+        B, T, C = x.size()
+
+        # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with the corresponding
+        # weight matrices, do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
+        # weights for q, k, v) and then split the result along the `embedding size` dimension
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+
+        # in order to move the num_heads (nh) dimension right after the batch (B) dimension, we first split the
+        # embedding size (C) dimension into num_heads (nh) and head_size (hs)
         head_size = C // self.n_head
         k = k.view(B, T, self.n_head, head_size)
         q = q.view(B, T, self.n_head, head_size)
         v = v.view(B, T, self.n_head, head_size)
 
-        q = apply_rope(q, rope)
-        k = apply_rope(k, rope)
+        # "Unlike standard positional embeddings rotary embeddings must be applied at every layer"
+        q = apply_rope(q, rope)  # (B, T, nh, hs)
+        k = apply_rope(k, rope)  # (B, T, nh, hs)
 
+        # now the `key`, `query` and `value` tensors are correctly represented: for each element in a batch (B)
+        # there is a number of heads (nh), and for each head there is a sequence of elements (T), each of them
+        # represented by a vector of size `hs`
         k = k.transpose(1, 2)  # (B, nh, T, hs)
         q = q.transpose(1, 2)  # (B, nh, T, hs)
         v = v.transpose(1, 2)  # (B, nh, T, hs)
 
         if kv_cache is not None:
-            cache_k, cache_v = kv_cache
+            cache_k, cache_v = kv_cache  # 2 * (B, nh, max_seq_length, hs)
             # check if reached token limit
             if input_pos[-1] >= max_seq_length:
+                # if we have reached the token limit, there is no space to put the newly calculated `key` and `value`
+                # right next to the cached ones, so we rotate the cache tensor along the `max_seq_length` dimension
+                # by one element to the left: this frees up space for the new `key` and `value`
                 input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                 # shift 1 position to the left
                 cache_k = torch.roll(cache_k, -1, dims=2)
                 cache_v = torch.roll(cache_v, -1, dims=2)
-            k = cache_k.index_copy(2, input_pos, k)
-            v = cache_v.index_copy(2, input_pos, v)
+            k = cache_k.index_copy(2, input_pos, k)  # (B, nh, max_seq_length, hs)
+            v = cache_v.index_copy(2, input_pos, v)  # (B, nh, max_seq_length, hs)
             kv_cache = k, v
 
         # efficient attention using Flash Attention CUDA kernels
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)  # (B, nh, T, hs)
 
+        # "Adapters are applied to the topmost layers to better tune the language
+        # representations with higher-level semantics".
         if self.block_idx >= self.adapter_start_layer:
             if adapter_kv_cache is not None:
-                ak, av = adapter_kv_cache
+                ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
             else:
                 prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
                 aT = prefix.size(1)
-                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
-                ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
-                av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
+                _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+                ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
+                av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
                 adapter_kv_cache = (ak, av)
 
-            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
-            ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
+            # Apply cross-attention with `query`, `adapter_key`, `adapter_value` and sum the output with the output
+            # of the self-attention step. This is mathematically equivalent to concatenating the prefix with the input, as per the paper.
+            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)  # (T, aT)
+            # ↓ (B, nh, T, hs) @ (B, nh, aT, hs).mT --> (B, nh, T, aT) @ (B, nh, aT, hs) --> (B, nh, T, hs)
+            ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)  # (B, nh, T, hs)
             y = y + self.gating_factor * ay
 
         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
 
         # output projection
-        y = self.c_proj(y)
+        y = self.c_proj(y)  # (B, T, C)
 
         return y, kv_cache, adapter_kv_cache
 
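The cache-rotation step commented in the hunk above can be looked at in isolation. Below is a tiny sketch with an assumed toy cache (not this repository's API, just the same two torch calls): once the cache is full, `torch.roll` shifts every cached position one slot to the left and `index_copy` overwrites the freed last slot with the newest entry.

    import torch

    # assumed toy cache: batch=1, heads=1, max_seq_length=4, head_size=2
    cache_k = torch.arange(8, dtype=torch.float32).reshape(1, 1, 4, 2)
    new_k = torch.full((1, 1, 1, 2), 9.0)       # key of the incoming token

    input_pos = torch.tensor([3])               # write position: the last slot
    cache_k = torch.roll(cache_k, -1, dims=2)   # drop the oldest position, shift the rest left
    cache_k = cache_k.index_copy(2, input_pos, new_k)

    print(cache_k[0, 0, :, 0])                  # tensor([2., 4., 6., 9.])
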
@@ -202,9 +262,9 @@ def forward(
         assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"
 
         if self.rope_cache is None:
-            self.rope_cache = self.build_rope_cache(idx)
+            self.rope_cache = self.build_rope_cache(idx)  # (block_size, head_size / 2, 2)
         if self.mask_cache is None:
-            self.mask_cache = self.build_mask_cache(idx)
+            self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)
 
         if input_pos is not None:
             rope = self.rope_cache.index_select(0, input_pos)
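
For the `(1, 1, block_size, block_size)` shape noted above, here is a plausible construction of such a causal mask cache (assumed for illustration; the actual `build_mask_cache` lives elsewhere in the model code):

    import torch

    block_size = 6  # assumed toy size
    ones = torch.ones(block_size, block_size, dtype=torch.bool)
    mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)  # (1, 1, block_size, block_size)

    T = 4
    mask = mask_cache[:, :, :T, :T]      # slice used for a sequence of length T, as in the hunk below
    print(mask_cache.shape, mask.shape)  # torch.Size([1, 1, 6, 6]) torch.Size([1, 1, 4, 4])
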
@@ -215,7 +275,7 @@ def forward(
             mask = self.mask_cache[:, :, :T, :T]
 
         # forward the model itself
-        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
 
         if input_pos is None:  # proxy for use_cache=False
             for block in self.transformer.h:
@@ -235,9 +295,9 @@ def forward(
                     x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
                 )
 
-        x = self.transformer.ln_f(x)
+        x = self.transformer.ln_f(x)  # (B, T, n_embd)
 
-        logits = self.lm_head(x)  # (b, t, vocab_size)
+        logits = self.lm_head(x)  # (B, T, vocab_size)
 
         return logits
 
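
Putting the shape annotations together, here is a toy end-to-end pass with assumed dimensions (RoPE, KV caching, the adapter branch and the rest of the transformer block are omitted for brevity; names like `c_attn` and `lm_head` only mirror the diff for readability):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, T, C, n_head, vocab_size = 2, 6, 32, 4, 100   # assumed toy dimensions
    head_size = C // n_head

    c_attn = nn.Linear(C, 3 * C, bias=False)   # stand-in for self.c_attn
    c_proj = nn.Linear(C, C, bias=False)       # stand-in for self.c_proj
    lm_head = nn.Linear(C, vocab_size, bias=False)

    x = torch.randn(B, T, C)                                     # token embeddings, (B, T, n_embd)
    q, k, v = c_attn(x).split(C, dim=2)                          # (B, T, 3 * C) --> 3 * (B, T, C)
    q = q.view(B, T, n_head, head_size).transpose(1, 2)          # (B, nh, T, hs)
    k = k.view(B, T, n_head, head_size).transpose(1, 2)          # (B, nh, T, hs)
    v = v.view(B, T, n_head, head_size).transpose(1, 2)          # (B, nh, T, hs)

    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, nh, T, hs)
    y = c_proj(y.transpose(1, 2).contiguous().view(B, T, C))     # (B, T, C)

    logits = lm_head(y)                                          # (B, T, vocab_size)
    print(logits.shape)  # torch.Size([2, 6, 100])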