@@ -259,27 +259,30 @@ def forward(
         k = self.k_proj(y)
         v = self.v_proj(y)
 
+        # Apply positional embeddings
+        # k: [b, s_y, n_kv, h_d]
+        k = k.view(b, s_y, self.num_kv_heads, self.head_dim)
+        if self.pos_embeddings is not None:
+            k = self.pos_embeddings(k, input_pos=input_pos)
+
+        # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # to match q.
+
         # k: [b, s_y, n_kv, 1, h_d]
         # v: [b, s_y, n_kv, 1, h_d]
         k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
         v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
 
-        # if needed, expand the key and value tensors to have the same shape
+        # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
             k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
             v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
 
-        # llama applies the RoPE embeddings on tensors with shape
         # [b, s, n_h, h_d]
-        # Reshape the tensors before we apply RoPE
         k = k.reshape(b, s_y, -1, self.head_dim)
         v = v.reshape(b, s_y, -1, self.head_dim)
 
-        # Apply positional embeddings
-        if self.pos_embeddings is not None:
-            k = self.pos_embeddings(k, input_pos=input_pos)
-
         # [b, n_h, s, h_d]
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
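
The net effect of the hunk: positional embeddings are now applied to k while it is still shaped [b, s_y, n_kv, h_d], and only afterwards are the kv heads expanded to match the number of query heads. Below is a minimal standalone sketch of that view + expand + reshape sequence, using illustrative shapes and a stub in place of the real positional-embedding module (not the actual class); v follows the same path and is omitted for brevity.

```python
import torch

# Illustrative sizes; these particular values are assumptions, not from the diff.
b, s_y, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
q_per_kv = num_heads // num_kv_heads  # queries per kv head (4 here)

# Stand-in for the output of a k projection: [b, s_y, n_kv * h_d]
k = torch.randn(b, s_y, num_kv_heads * head_dim)

# Apply positional embeddings while k is [b, s_y, n_kv, h_d].
# A real RoPE module would go here; an identity function stands in for it.
k = k.view(b, s_y, num_kv_heads, head_dim)
pos_embeddings = lambda x, input_pos=None: x
k = pos_embeddings(k, input_pos=None)

# View + expand: [b, s_y, n_kv, 1, h_d] -> [b, s_y, n_kv, q_per_kv, h_d].
# expand() only creates a broadcasted view; no data is copied yet.
k = k.view(b, s_y, num_kv_heads, 1, head_dim)
if num_heads != num_kv_heads:
    k = k.expand(b, s_y, num_kv_heads, q_per_kv, head_dim)

# reshape() materializes the copy, collapsing to [b, s_y, n_h, h_d].
k = k.reshape(b, s_y, -1, head_dim)

# [b, n_h, s_y, h_d] for attention
k = k.transpose(1, 2)
assert k.shape == (b, num_heads, s_y, head_dim)
```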