
Commit 75f6975

Apply rope on k earlier for efficiency
1 parent d7fae96 commit 75f6975

1 file changed: 10 additions, 7 deletions

torchtune/modules/attention.py

Lines changed: 10 additions & 7 deletions
@@ -259,27 +259,30 @@ def forward(
         k = self.k_proj(y)
         v = self.v_proj(y)

+        # Apply positional embeddings
+        # k: [b, s_y, n_kv, h_d]
+        k = k.view(b, s_y, self.num_kv_heads, self.head_dim)
+        if self.pos_embeddings is not None:
+            k = self.pos_embeddings(k, input_pos=input_pos)
+
+        # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # to match q.
+
         # k: [b, s_y, n_kv, 1, h_d]
         # v: [b, s_y, n_kv, 1, h_d]
         k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
         v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)

-        # if needed, expand the key and value tensors to have the same shape
+        # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
             k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
             v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)

-        # llama applies the RoPE embeddings on tensors with shape
         # [b, s, n_h, h_d]
-        # Reshape the tensors before we apply RoPE
         k = k.reshape(b, s_y, -1, self.head_dim)
         v = v.reshape(b, s_y, -1, self.head_dim)

-        # Apply positional embeddings
-        if self.pos_embeddings is not None:
-            k = self.pos_embeddings(k, input_pos=input_pos)
-
         # [b, n_h, s, h_d]
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
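For context, a minimal standalone sketch (not part of the commit; the shapes and names num_kv_heads, q_per_kv, head_dim mirror the diff, the concrete sizes are made up) of why rotating k before the kv-head expansion is cheaper: RoPE touches num_kv_heads heads instead of num_heads, i.e. q_per_kv times fewer elements.

import torch

b, s_y, head_dim = 2, 128, 64
num_heads, num_kv_heads = 32, 8        # grouped-query attention: 4 query heads per kv head
q_per_kv = num_heads // num_kv_heads

k = torch.randn(b, s_y, num_kv_heads, head_dim)

# New order (this commit): RoPE rotates only the n_kv key heads.
elems_rotated_now = k.numel()          # b * s_y * n_kv * h_d

# Old order: k was first expanded to n_h heads, then rotated.
k_expanded = (
    k.view(b, s_y, num_kv_heads, 1, head_dim)
     .expand(b, s_y, num_kv_heads, q_per_kv, head_dim)
     .reshape(b, s_y, -1, head_dim)
)
elems_rotated_before = k_expanded.numel()  # b * s_y * n_h * h_d

print(elems_rotated_before // elems_rotated_now)  # q_per_kv, i.e. 4x fewer elements rotated now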

0 commit comments
