Commit 5a5c58b

additional method
1 parent 4d5a6d9 commit 5a5c58b

1 file changed, +13 −6 lines changed

src/attention.jl

Lines changed: 13 additions & 6 deletions
@@ -37,19 +37,26 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N
 
     q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
 
-    # Multihead attention. TODO create fastpath for singlehead attention.
-    q, k, v = split_heads.((q, k, v), nheads)
-    x, α = _dot_product_attention(q, k, v; kws...)
-    x = join_heads(x)
+    x, α = dot_product_attention(q, k, v; nheads, kws...)
 
     x = reshape(x, size(x, 1), size(x, 2), batch_size...)
     α = reshape(α, size(α)[1:3]..., batch_size...)
     return x, α
 end
 
-function _dot_product_attention(q::AA4, k::AA4, v::AA4;
-        fdrop=identity, bias=nothing, mask=nothing)
+function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...)
+    # Multihead attention. TODO create fastpath for singlehead attention.
+    q, k, v = split_heads.((q, k, v), nheads)
+    x, α = _dot_product_attention(q, k, v; kws...)
+    return join_heads(x), α
+end
 
+function _dot_product_attention(q::AA4, k::AA4, v::AA4;
+        fdrop=identity, bias=nothing, mask=nothing)
+    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
+    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
+    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
+
     α = dot_product_attention_scores(q, k; fdrop, bias, mask)
     # [α] = [kv_len, q_len, nheads, batch_size]
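
For context, a minimal usage sketch of the entry points this commit rearranges. It assumes the file belongs to NNlib.jl, where dot_product_attention is the public API; the array sizes are made up for illustration.

using NNlib  # assumption: this diff is NNlib.jl's src/attention.jl

# Hypothetical sizes: embedding dim 8, 5 query and 6 key/value positions,
# batch of 3, split across 2 heads (8 ÷ 2 = 4 dims per head).
q = rand(Float32, 8, 5, 3)
k = rand(Float32, 8, 6, 3)
v = rand(Float32, 8, 6, 3)

# The new 3-dim method splits the heads, calls _dot_product_attention,
# and joins the heads again. The N-dim method flattens any trailing batch
# dimensions and then dispatches to this more specific 3-dim method.
x, α = dot_product_attention(q, k, v; nheads=2)

size(x)  # (8, 5, 3): [v_dim, q_len, batch_size]
size(α)  # (6, 5, 2, 3): [kv_len, q_len, nheads, batch_size]

The split_heads/join_heads helpers are defined elsewhere in this file; a plausible reshape-based definition, consistent with the shape comments in the diff, would be:

# Sketch only: [dim, len, batch] -> [dim ÷ nheads, nheads, len, batch] and back.
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)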

0 commit comments
