@@ -3,21 +3,28 @@ const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
-    dot_product_attention(query, key, value; [bias, fdrop, mask, nheads])
+    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
-and the sequece length, then an arbitrary number of batch dimensions or none.
+and the sequence length, then an arbitrary number of batch dimensions or none.
+
+Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
+of size `(kv_len, q_len, nheads, batch_size...)`.
+
+See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- - `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+ - `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+   It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout).
- - `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+ - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+   The mask will be applied to the attention scores before the softmax.
  Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.
@@ -28,36 +35,37 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
-function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N
+function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    batch_size = size(q)[3:end]
-
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
-    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
-    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
-
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

-    x, α = dot_product_attention(q, k, v; nheads, kws...)
+    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end
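For a concrete sense of the API documented in the docstring above, here is a minimal usage sketch. The sizes are made up, and it assumes these methods are exported as in NNlib; the expected output sizes follow directly from the docstring.

```julia
using NNlib  # assumption: dot_product_attention is exported from NNlib

q = rand(Float32, 16, 20, 2)   # (qk_dim, q_len, batch_size...)
k = rand(Float32, 16, 20, 2)   # (qk_dim, kv_len, batch_size...)
v = rand(Float32, 24, 20, 2)   # (v_dim, kv_len, batch_size...)

# 4 heads and a causal mask; qk_dim and v_dim are chosen divisible by nheads,
# and q_len == kv_len so the causal mask broadcasts cleanly.
y, α = dot_product_attention(q, k, v; nheads=4, mask=:causal)

size(y)  # (24, 20, 2)     == (v_dim, q_len, batch_size...)
size(α)  # (20, 20, 4, 2)  == (kv_len, q_len, nheads, batch_size...)
```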
-function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...)
+function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
+            fdrop=identity, mask=nothing, nheads=1)
+
+    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
+    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
+    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
+
    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
-    x, α = _dot_product_attention(q, k, v; kws...)
+    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end
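`split_heads` and `join_heads` are used above but not shown in these hunks. A plausible sketch, assuming they are plain reshapes over the feature dimension, matching the `qk_dim ÷ nheads` layout described in the comments of `_dot_product_attention` below:

```julia
# Not part of this diff: hypothetical definitions of the head-splitting helpers.
# split_heads maps (dim, len, batch...) to (dim ÷ nheads, nheads, len, batch...);
# join_heads reverses it by merging the first two dimensions.
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
```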
-function _dot_product_attention(q::AA4, k::AA4, v::AA4;
-            fdrop=identity, bias=nothing, mask=nothing)
+function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
-
-    α = dot_product_attention_scores(q, k; fdrop, bias, mask)
+
+    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
@@ -70,14 +78,16 @@ function _dot_product_attention(q::AA4, k::AA4, v::AA4;
end

"""
-    dot_product_attention_scores(query, key; [bias, droput_fn, mask])
+    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
+Input arrays must have dimensions
+`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

-Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
+See [`dot_product_attention`](@ref) for more details.
"""
-function dot_product_attention_scores(q::AA4{T}, k::AA4{T};
-            fdrop=identity, mask=nothing, bias=nothing) where T
+function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
+            fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
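Finally, a short sketch of calling `dot_product_attention_scores` directly on head-split inputs, again with made-up sizes. The normalization check assumes the returned scores are softmax-normalized over the key dimension, which the docstring's mention of the softmax and the `[kv_len, q_len, nheads, batch_size]` layout suggest.

```julia
# Hypothetical sizes; q and k are already head-split as the docstring requires:
# (num_features ÷ nheads, nheads, sequence_length, batch_size).
nheads, d, len, batch = 4, 16, 20, 2
qh = rand(Float32, d ÷ nheads, nheads, len, batch)
kh = rand(Float32, d ÷ nheads, nheads, len, batch)

α = dot_product_attention_scores(qh, kh)
size(α)  # (20, 20, 4, 2) == (kv_len, q_len, nheads, batch_size)

# Each query's scores should form a probability distribution over the keys
# (the first dimension), so they sum to one along dims=1.
all(sum(α; dims=1) .≈ 1)
```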