
Commit e61cf56

Fixes
1 parent e417acf commit e61cf56

2 files changed: 24 additions & 7 deletions


torchtune/modules/attention.py

Lines changed: 6 additions & 1 deletion
@@ -143,9 +143,14 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            q_per_kv=self.q_per_kv,
+            attn_dropout=self.attn_dropout,
+            is_causal=self.is_causal,
             attention_fn=self._attention_call,
             kv_cache=self.kv_cache,
-            q_per_kv=self.q_per_kv,
         )

     def setup_cache(
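
In short, the attention module now forwards its full configuration to SDPA instead of only q_per_kv. Below is a minimal, standalone construction sketch of the new call; the import path, the stand-in attention callable, and the concrete numbers are assumptions for illustration, not part of this commit.

import torch.nn.functional as F

from torchtune.modules.sdpa import SDPA  # module path as it appears in this commit


# Stand-in for the callable produced by _sdpa_or_flex_attention() in
# MultiHeadAttention.__init__ (see the attention.py hunk above).
def attention_call(q, k, v, mask=None, dropout_p=0.0, is_causal=False):
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal
    )


sdpa = SDPA(
    num_kv_heads=8,            # hypothetical config values
    num_heads=32,
    head_dim=128,
    q_per_kv=32 // 8,          # queries per KV head
    attn_dropout=0.0,
    is_causal=True,
    attention_fn=attention_call,
    kv_cache=None,             # no cache in this sketch; see setup_cache()
)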

torchtune/modules/sdpa.py

Lines changed: 18 additions & 6 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
 from torch import nn, Tensor


@@ -17,22 +19,32 @@ class SDPA(nn.Module):

     def __init__(
         self,
+        num_kv_heads: int,
+        num_heads: int,
+        head_dim: int,
+        q_per_kv: int,
+        attn_dropout: float,
+        is_causal: bool,
         attention_fn,
         kv_cache,
-        q_per_kv,
     ) -> None:
         super().__init__()
-        self._attention_fn = attention_fn
-        self.kv_cache = kv_cache
+        self.num_kv_heads = num_kv_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
         self.q_per_kv = q_per_kv
+        self.attn_dropout = attn_dropout
+        self.is_causal = is_causal
+        self._attention_fn = attention_fn
+        self._kv_cache = kv_cache

     def kv_cache_update(
         self,
         input_pos: Tensor,
         k: Tensor,
         v: Tensor,
     ) -> Tuple[Tensor, Tensor]:
-        k, v = self.kv_cache.update(input_pos, k, v)
+        k, v = self._kv_cache.update(input_pos, k, v)
         return k, v

     def sdpa(
@@ -72,7 +84,7 @@ def sdpa(
             v,
             mask=mask,
             dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            is_causal=self._kv_cache is None and mask is None and self.is_causal,
         )
         # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
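
A side note on the last hunk: the kernel-level causal flag handed to the attention callable is only enabled when the module has no KV-cache and no explicit mask. A small sketch of that predicate; the helper name and the sample values are made up for illustration.

# Mirrors the expression forwarded as is_causal in the hunk above:
#     self._kv_cache is None and mask is None and self.is_causal
def kernel_is_causal(kv_cache, mask, module_is_causal):
    return kv_cache is None and mask is None and module_is_causal


# Plain full-sequence forward: no cache and no packed-sample mask -> fast causal path.
assert kernel_is_causal(None, None, True)
# A cache or an explicit mask disables the kernel flag; any masking then has to
# come from the mask argument itself.
assert not kernel_is_causal(object(), None, True)
assert not kernel_is_causal(None, object(), True)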
