@@ -37,19 +37,26 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) w

     q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

-    # Multihead attention. TODO create fastpath for singlehead attention.
-    q, k, v = split_heads.((q, k, v), nheads)
-    x, α = _dot_product_attention(q, k, v; kws...)
-    x = join_heads(x)
+    x, α = dot_product_attention(q, k, v; nheads, kws...)

     x = reshape(x, size(x, 1), size(x, 2), batch_size...)
     α = reshape(α, size(α)[1:3]..., batch_size...)
     return x, α
 end

-function _dot_product_attention(q::AA4, k::AA4, v::AA4;
-        fdrop=identity, bias=nothing, mask=nothing)
+function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...)
+    # Multihead attention. TODO create fastpath for singlehead attention.
+    q, k, v = split_heads.((q, k, v), nheads)
+    x, α = _dot_product_attention(q, k, v; kws...)
+    return join_heads(x), α
+end

+function _dot_product_attention(q::AA4, k::AA4, v::AA4;
+        fdrop=identity, bias=nothing, mask=nothing)
+    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
+    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
+    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
+
     α = dot_product_attention_scores(q, k; fdrop, bias, mask)
     # [α] = [kv_len, q_len, nheads, batch_size]
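For context, a minimal usage sketch of the refactored 3-arg entry point introduced above, assuming this is NNlib's exported dot_product_attention. The array sizes are hypothetical; the shape conventions follow the comments in the diff (q/k/v as feature × length × batch, α as [kv_len, q_len, nheads, batch_size]).

using NNlib

# Hypothetical sizes for illustration: qk_dim = v_dim = 8, q_len = 4,
# kv_len = 6, batch_size = 2, split across nheads = 2.
q = rand(Float32, 8, 4, 2)   # (qk_dim, q_len, batch_size)
k = rand(Float32, 8, 6, 2)   # (qk_dim, kv_len, batch_size)
v = rand(Float32, 8, 6, 2)   # (v_dim, kv_len, batch_size)

x, α = dot_product_attention(q, k, v; nheads=2)
size(x)  # (8, 4, 2)     i.e. (v_dim, q_len, batch_size)
size(α)  # (6, 4, 2, 2)  i.e. (kv_len, q_len, nheads, batch_size)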