 )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
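The added `repeat_kv` helper is the standard grouped-query-attention (GQA) utility: it tiles each key/value head `n_rep` times so a smaller KV cache can serve all attention heads. A minimal, self-contained shape check; the function body is copied from the hunk above, and the tensor sizes are illustrative rather than taken from the model config:

```python
import torch

# Sketch: same body as the repeat_kv added in this commit, reproduced to show
# the expected shape change from n_kv_heads to n_attention_heads.
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 8, 64)       # (batch, n_kv_heads=2, seqlen, head_dim)
out = repeat_kv(kv, n_rep=4)        # 2 kv heads expanded to serve 8 query heads
assert out.shape == (1, 8, 8, 64)   # (batch, n_attention_heads, seqlen, head_dim)
```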
@@ -59,13 +73,13 @@ def prepare_sha(self):
         self.wk_sha = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=False)
-                for _ in range(self.n_heads)
+                for _ in range(self.n_kv_heads)
             ]
         )
         self.wv_sha = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=False)
-                for _ in range(self.n_heads)
+                for _ in range(self.n_kv_heads)
             ]
         )
 
@@ -76,6 +90,7 @@ def prepare_sha(self):
             self.wq_sha[i].weight.data.copy_(
                 self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
             )
+        for i in range(self.n_kv_heads):
             self.wk_sha[i].weight.data.copy_(
                 self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
             )
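With GQA, the fused `wk`/`wv` projections produce only `n_kv_heads * head_dim` output features, so the per-head single-head linears and the weight-copy loop above now iterate over `n_kv_heads` rather than `n_heads`. A rough sketch of that slicing under this assumption; the sizes and the standalone `wk` below are illustrative stand-ins, not the model's actual config:

```python
import torch
from torch import nn

dim, head_dim, n_kv_heads = 256, 64, 2                     # illustrative sizes only
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)     # assumed fused GQA key projection

# One single-head linear per kv head, each taking a head_dim-sized slice of rows.
wk_sha = nn.ModuleList(
    [nn.Linear(dim, head_dim, bias=False) for _ in range(n_kv_heads)]
)
for i in range(n_kv_heads):
    wk_sha[i].weight.data.copy_(wk.weight[i * head_dim : (i + 1) * head_dim])

x = torch.randn(1, 5, dim)
fused = wk(x).reshape(1, 5, n_kv_heads, head_dim)
per_head = torch.stack([w(x) for w in wk_sha], dim=2)
assert torch.allclose(fused, per_head)                     # per-head split matches fused projection
```

Each single-head linear takes the same `head_dim` rows of the fused weight its head would have consumed, so the stacked per-head outputs reproduce the fused projection.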
@@ -97,30 +112,27 @@ def forward_sha(
         v = [wv_sha(hidden_states) for wv_sha in self.wv_sha]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+        for i in range(len(k)):
             k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
 
-        output_kh, output_vh, output_y = [], [], []
+        output_y = []
+        kh, vh = [], []
         for i, _ in enumerate(k_caches):
-            # cat at the seq dim
-            kh = torch.cat([k_caches[i], k[i]], dim=-1)
-            vh = torch.cat([v_caches[i], v[i]], dim=1)
+            kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
+            vh.append(torch.cat([v_caches[i], v[i]], dim=1))
 
-            attn = q[i] @ kh
+        for i, _ in enumerate(q):
+            cache_idx = i // self.num_key_value_groups
+            attn = q[i] @ kh[cache_idx]
             attn = attn / self.scale + atten_mask
             attn = self.attn_softmax(attn)
-            y = attn @ vh
+            y = attn @ vh[cache_idx]
 
-            if self.output_new_cache_only:
-                output_kh.append(k[i])
-                output_vh.append(v[i])
-            else:
-                output_kh.append(kh)
-                output_vh.append(vh)
             output_y.append(y)
 
         y = torch.concat(output_y, dim=-1)
         y = self.wo(y)
-        return y, output_kh, output_vh
+        return y, k, v
 
     def forward(
         self,
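The core of the GQA change in `forward_sha` is how query heads are routed to shared caches: `cache_idx = i // self.num_key_value_groups` maps each of the `n_heads` query heads onto one of the `n_kv_heads` cache entries. A tiny numeric illustration, assuming the usual convention `num_key_value_groups = n_heads // n_kv_heads` (as in the Hugging Face Llama implementation):

```python
n_heads, n_kv_heads = 8, 2
num_key_value_groups = n_heads // n_kv_heads   # 4 query heads share each kv head

for i in range(n_heads):
    cache_idx = i // num_key_value_groups
    print(f"query head {i} -> kv cache {cache_idx}")
# query heads 0-3 attend against kv cache 0, heads 4-7 against kv cache 1
```

Note that `forward_sha` now always returns just the new per-head `k`/`v` slices rather than the concatenated caches; the dense `forward` path keeps the explicit `output_new_cache_only` branch.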
@@ -142,24 +154,28 @@ def forward(
         k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         output_kh, output_vh, output_y = [], [], []
-
+        kh, vh = [], []
         for i, _ in enumerate(k_caches):
-            # cat at the seq dim
-            kh = torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)
-            vh = torch.cat([v_caches[i], v[:, :, i, :]], dim=1)
+            kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
+            vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))
+
+        for i in range(self.n_heads):
+            cache_idx = i // self.num_key_value_groups
 
-            attn = q[:, :, i, :] @ kh
+            attn = q[:, :, i, :] @ kh[cache_idx]
             attn = attn / self.scale + atten_mask
             attn = self.attn_softmax(attn)
-            y = attn @ vh
+            y = attn @ vh[cache_idx]
 
+            output_y.append(y)
+
+        for i in range(len(k_caches)):
             if self.output_new_cache_only:
                 output_kh.append(k[:, i, :, :])
                 output_vh.append(v[:, :, i, :])
             else:
-                output_kh.append(kh)
-                output_vh.append(vh)
-            output_y.append(y)
+                output_kh.append(kh[i])
+                output_vh.append(vh[i])
 
         y = torch.concat(output_y, dim=-1)
         y = self.wo(y)
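In both attention paths the per-head K cache is stored pre-transposed as `(batch, head_dim, seq)` and the V cache as `(batch, seq, head_dim)` (see `get_example_inputs` below), so the current token's key is appended along `dim=-1` and its value along `dim=1`. A small shape sketch of that cache update, with illustrative sizes:

```python
import torch

batch, head_dim, max_seq = 1, 64, 128

k_cache = torch.zeros(batch, head_dim, max_seq - 1)   # pre-transposed key cache
v_cache = torch.zeros(batch, max_seq - 1, head_dim)   # value cache

k_new = torch.randn(batch, head_dim, 1)               # key of the current token
v_new = torch.randn(batch, 1, head_dim)               # value of the current token

kh = torch.cat([k_cache, k_new], dim=-1)              # append along the seq dim of K
vh = torch.cat([v_cache, v_new], dim=1)               # append along the seq dim of V

assert kh.shape == (batch, head_dim, max_seq)
assert vh.shape == (batch, max_seq, head_dim)
```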
@@ -246,10 +262,10 @@ def forward(
 
         hidden_states = self.tok_embeddings(tokens)
         for ind, decoder_layer in enumerate(self.layers):
-            offset_k = ind * self.n_heads
-            offset_v = self.n_layers * self.n_heads + offset_k
-            k_caches = args[offset_k : offset_k + self.n_heads]
-            v_caches = args[offset_v : offset_v + self.n_heads]
+            offset_k = ind * self.n_kv_heads
+            offset_v = self.n_layers * self.n_kv_heads + offset_k
+            k_caches = args[offset_k : offset_k + self.n_kv_heads]
+            v_caches = args[offset_v : offset_v + self.n_kv_heads]
             hidden_states, k, v = decoder_layer(
                 hidden_states,
                 freqs_cos=freqs_cos,
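The model-level `forward` receives the caches as one flat tuple: the `n_layers * n_kv_heads` key caches first, then the value caches, which is why the offsets above now scale with `n_kv_heads`. A quick sketch of that indexing under the same layout assumption (the layer/head labels are placeholders):

```python
n_layers, n_kv_heads = 4, 2
# Assumed flat cache layout: all k caches for every layer first, then all v caches.
args = [f"k_L{layer}_H{h}" for layer in range(n_layers) for h in range(n_kv_heads)]
args += [f"v_L{layer}_H{h}" for layer in range(n_layers) for h in range(n_kv_heads)]

for ind in range(n_layers):
    offset_k = ind * n_kv_heads
    offset_v = n_layers * n_kv_heads + offset_k
    k_caches = args[offset_k : offset_k + n_kv_heads]
    v_caches = args[offset_v : offset_v + n_kv_heads]
    print(ind, k_caches, v_caches)
# 0 ['k_L0_H0', 'k_L0_H1'] ['v_L0_H0', 'v_L0_H1']  ... and so on per layer
```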
@@ -275,7 +291,7 @@ def get_example_inputs(self):
         atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
         atten_mask[:, -1] = 0
         for _ in range(self.n_layers):
-            for _ in range(self.n_heads):
+            for _ in range(self.n_kv_heads):
                 # transpose first to decrease the runtime efforts
                 k_cache.append(
                     torch.zeros(
@@ -299,40 +315,6 @@ def get_example_inputs(self):
             v_cache,
         )
 
-    def get_export_inputs(self):
-        tokens = torch.randint(
-            self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
-        )
-        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
-        # this is important for torch.export not to take it as dummy input
-        k_cache, v_cache = [], []
-        atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
-        atten_mask[:, -1] = 0
-        for _ in range(self.n_layers):
-            for _ in range(self.n_heads):
-                # transpose first to decrease the runtime efforts
-                k_cache.append(
-                    torch.randn(
-                        self.max_batch_size,
-                        self.head_dim,
-                        self.max_seq_len - 1,
-                    )
-                )
-                v_cache.append(
-                    torch.randn(
-                        self.max_batch_size,
-                        self.max_seq_len - 1,
-                        self.head_dim,
-                    )
-                )
-        return (
-            tokens,
-            pos_ids,
-            atten_mask,
-            k_cache,
-            v_cache,
-        )
-
     def get_metadata(self):
         # TODO: modify this when enabling LLAMA 7B
         return {
@@ -344,7 +326,7 @@ def get_metadata(self):
344326 "get_max_seq_len" : self .max_seq_len ,
345327 "get_n_bos" : 1 ,
346328 "get_n_eos" : 1 ,
347- "get_n_kv_heads" : self .n_heads ,
329+ "get_n_kv_heads" : self .n_kv_heads ,
348330 "get_n_layers" : self .n_layers ,
349331 "get_vocab_size" : self .vocab_size ,
350332 }