Commit 561dd68
xxyux and umiswing authored
add sink_attention (#2461)
Co-authored-by: umiswing <[email protected]>
1 parent 3e2484a · commit 561dd68

6 files changed: +866 -32 lines changed

paddleformers/nn/attention/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,13 +24,15 @@
     "interface": ["AttentionInterface", "ALL_ATTENTION_FUNCTIONS"],
     "sdpa_attention": ["sdpa_attention_forward"],
     "utils": ["repeat_kv"],
+    "sink_impl": ["sink_attention_forward"],
 }

 if TYPE_CHECKING:
     from .eager_attention import *
     from .flashmask_attention import *
     from .interface import *
     from .sdpa_attention import *
+    from .sink_impl import *
     from .utils import *
 else:
     sys.modules[__name__] = _LazyModule(
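
The new "sink_impl": ["sink_attention_forward"] entry registers the function with the package's _LazyModule table, so it becomes importable from paddleformers.nn.attention without eagerly loading every attention backend. A minimal usage sketch, assuming paddleformers with this commit is installed (sink_impl.py itself is part of the commit but not shown in this excerpt):

# Resolved lazily through the mapping added above; no other backend is imported eagerly.
from paddleformers.nn.attention import sink_attention_forward

print(callable(sink_attention_forward))  # True once the lazy module resolves the symbol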

paddleformers/nn/attention/eager_attention.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def eager_attention_forward(
     query = paddle.transpose(x=query, perm=perm)
     key = paddle.transpose(x=key, perm=perm)
     value = paddle.transpose(x=value, perm=perm)
+
     attn_weights = paddle.matmul(query, key.transpose([0, 1, 3, 2])) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]

paddleformers/nn/attention/flashmask_attention.py

Lines changed: 23 additions & 12 deletions
@@ -18,30 +18,41 @@
 import paddle.nn as nn
 from paddle.nn.functional.flash_attention import flashmask_attention

+from .sink_impl import sink_attention_forward
+

 def flashmask_attention_forward(
     module: nn.Layer,
     query: paddle.Tensor,
     key: paddle.Tensor,
     value: paddle.Tensor,
-    attention_mask: Optional[paddle.Tensor] = None,
-    attn_mask_start_row_indices=None,
+    attn_mask_start_row_indices: paddle.Tensor,
     dropout: float = 0.0,
+    sink: Optional[paddle.Tensor] = None,
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs
 ):
-    if attn_mask_start_row_indices is not None:
-        attn_mask_start_row_indices = attn_mask_start_row_indices.unsqueeze(-1)
-
     # b,l,h,d
-    out = flashmask_attention(
-        query,
-        key,
-        value,
-        startend_row_indices=attn_mask_start_row_indices,
-        causal=True,
-    )
+    if sink is None:
+        out = flashmask_attention(
+            query,
+            key,
+            value,
+            startend_row_indices=attn_mask_start_row_indices,
+            causal=True,
+        )
+    else:
+        out = sink_attention_forward(
+            query,
+            key,
+            value,
+            sink,
+            startend_row_indices=attn_mask_start_row_indices,
+            dropout_p=dropout,
+            softmax_scale=scaling,
+            causal=is_causal,
+        )
     out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

     return out, None
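
sink_impl.py is outside this excerpt, so the exact implementation of sink_attention_forward is not shown. For orientation, below is a dense reference sketch of what attention sinks are commonly taken to mean: one learnable logit per head is appended to the score matrix before the softmax and dropped afterwards, so it only drains probability mass away from the real keys. The function name, the [num_heads]-shaped sink, and the b,l,h,d layout are assumptions for illustration; the sketch also ignores the startend_row_indices and dropout_p arguments that the real path forwards.

import paddle
import paddle.nn.functional as F


def sink_attention_reference(query, key, value, sink, softmax_scale=None, causal=True):
    """Hypothetical dense sketch of sink attention (not the sink_impl.py from this commit).

    query/key/value: [batch, seq_len, num_heads, head_dim] (the b,l,h,d layout used above).
    sink: [num_heads] learnable logits appended as a virtual key column, dropped after softmax.
    """
    b, q_len, h, d = query.shape
    scale = softmax_scale if softmax_scale is not None else d**-0.5

    q = query.transpose([0, 2, 1, 3])  # -> [b, h, l, d]
    k = key.transpose([0, 2, 1, 3])
    v = value.transpose([0, 2, 1, 3])

    scores = paddle.matmul(q, k, transpose_y=True) * scale  # [b, h, l, l]
    if causal:
        neg_inf = paddle.full([q_len, q_len], float("-inf"), dtype=scores.dtype)
        scores = scores + paddle.triu(neg_inf, diagonal=1)  # mask future positions

    # Append the per-head sink logit as one extra column so it competes in the softmax.
    sink_col = sink.astype(scores.dtype).reshape([1, h, 1, 1]).expand([b, h, q_len, 1])
    probs = F.softmax(paddle.concat([scores, sink_col], axis=-1), axis=-1)

    # Drop the sink column: it contributes no value, it only shrinks the remaining weights.
    out = paddle.matmul(probs[..., :-1], v)  # [b, h, l, d]
    return out.transpose([0, 2, 1, 3])  # back to [b, l, h, d]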

paddleformers/nn/attention/sdpa_attention.py

Lines changed: 18 additions & 4 deletions
@@ -18,6 +18,7 @@
 import paddle.nn as nn

 from ...utils.masking_utils import _gen_from_sparse_attn_mask_indices
+from .sink_impl import sink_attention_forward


 def sdpa_attention_forward(
@@ -28,20 +29,33 @@ def sdpa_attention_forward(
     attention_mask: Optional[paddle.Tensor] = None,
     attn_mask_start_row_indices=None,
     dropout: float = 0.0,
+    sink: Optional[paddle.Tensor] = None,
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs,
 ):
     # query: b l h d
-
     if is_causal is None and attn_mask_start_row_indices is None:
         is_causal = query.shape[1] > 1 and attention_mask is None and getattr(module, "is_causal", True)
     elif attn_mask_start_row_indices is not None:
         is_causal = False
         attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, query.dtype)

-    attn_output = nn.functional.scaled_dot_product_attention(
-        query, key, value, attention_mask, dropout, is_causal=is_causal, training=module.training
-    )
+    if sink is None:
+        attn_output = nn.functional.scaled_dot_product_attention(
+            query, key, value, attention_mask, dropout, is_causal=is_causal, training=module.training
+        )
+    else:
+        attn_output = sink_attention_forward(
+            query,
+            key,
+            value,
+            sink,
+            attention_mask=attention_mask,
+            startend_row_indices=None,
+            dropout_p=dropout,
+            softmax_scale=scaling,
+            causal=is_causal,
+        )
     attn_output = paddle.reshape(x=attn_output, shape=[0, 0, attn_output.shape[2] * attn_output.shape[3]])
     return attn_output, None
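
Both the SDPA and FlashMask forwards keep their previous behaviour when sink is None and only route to sink_attention_forward when a sink tensor is supplied. A hedged usage sketch of the new SDPA dispatch follows; the DummyAttention layer, the shapes, and the [num_heads]-shaped sink are made up for illustration, and the sink branch relies on the sink_impl.py added by this commit but not shown here.

import paddle
import paddle.nn as nn

from paddleformers.nn.attention import sdpa_attention_forward


class DummyAttention(nn.Layer):
    """Hypothetical caller; sdpa_attention_forward reads .training (and .is_causal when is_causal is None)."""

    is_causal = True


batch, seq_len, num_heads, head_dim = 2, 16, 4, 64
q = paddle.randn([batch, seq_len, num_heads, head_dim])
k = paddle.randn([batch, seq_len, num_heads, head_dim])
v = paddle.randn([batch, seq_len, num_heads, head_dim])
sink = paddle.zeros([num_heads])  # assumed per-head sink logits
module = DummyAttention()

# sink omitted: the original scaled_dot_product_attention path is used.
out_plain, _ = sdpa_attention_forward(module, q, k, v, dropout=0.0, is_causal=True)

# sink passed: the call is routed through sink_attention_forward instead.
out_sink, _ = sdpa_attention_forward(module, q, k, v, sink=sink, dropout=0.0, is_causal=True)

print(out_plain.shape, out_sink.shape)  # both [batch, seq_len, num_heads * head_dim]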
