@@ -3,7 +3,7 @@ const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
- dot_product_attention(query, key, value; [bias, droput_fn, mask, num_heads])
+ dot_product_attention(query, key, value; [bias, fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

@@ -15,12 +15,11 @@ and the sequece length, then an arbitrary number of batch dimensions or none.
- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- - `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`.
+ - `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+ - `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout).
+ - `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
- - `dropout_fn`: A dropout function to apply on the attention scores. Default `nothing`.
- - `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`.
- Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
- - `num_heads`: Number of heads to split the input arrays into. Default `1`.
+ - `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

@@ -29,7 +28,7 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
- function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...) where N
+ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N
batch_size = size(q)[3:end]

batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
@@ -39,7 +38,7 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

# Multihead attention. TODO create fastpath for singlehead attention.
- q, k, v = split_heads.((q, k, v), num_heads)
+ q, k, v = split_heads.((q, k, v), nheads)
x, α = _dot_product_attention(q, k, v; kws...)
x = join_heads(x)

@@ -49,17 +48,17 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4;
- dropout_fn=nothing, bias=nothing, mask=nothing)
+ fdrop=identity, bias=nothing, mask=nothing)

- α = dot_product_attention_scores(q, k; dropout_fn, bias, mask)
- # [α] = [kv_len, q_len, num_heads, batch_size]
+ α = dot_product_attention_scores(q, k; fdrop, bias, mask)
+ # [α] = [kv_len, q_len, nheads, batch_size]

# The following permutedims and batched_mul are equivalent to
# @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
vt = permutedims(v, (1, 3, 2, 4))
x = batched_mul(vt, α)
x = permutedims(x, (1, 3, 2, 4))
- # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+ # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size]

return x, α
end

@@ -68,35 +67,33 @@

Return the attention scores for the [`dot_product_attention`](@ref).

- Input arrays must have dimensions `(num_features ÷ num_heads, num_heads, sequence_length, batch_size)`
+ Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T};
- dropout_fn=nothing, mask=nothing, bias=nothing) where T
+ fdrop=identity, mask=nothing, bias=nothing) where T

- q = q ./ √T(size(q, 1))
-
# The following permutedims and batched_mul are equivalent to
- # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
+ # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
kt = permutedims(k, (3, 1, 2, 4))
- qt = permutedims(q, (1, 3, 2, 4))
- α = batched_mul(kt, qt)
- # [α] = [kv_len, q_len, num_heads, batch_size]
+ qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
+ logits = batched_mul(kt, qt)
+ # [logits] = [kv_len, q_len, nheads, batch_size]

if bias !== nothing
- α = α .+ bias
+ logits = logits .+ bias
end

if mask !== nothing
if mask === :causal
- mask = make_causal_mask(α)
+ mask = make_causal_mask(logits)
end
- neginf = typemin(eltype(α))
- α = ifelse.(mask, α, neginf)
+ neginf = typemin(eltype(logits))
+ logits = ifelse.(mask, logits, neginf)
end

- α = softmax(α, dims=1)
- return dropout_fn === nothing ? α : dropout_fn(α)
+ α = softmax(logits, dims=1)
+ return fdrop(α)
end

"""
@@ -116,7 +113,7 @@
trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

- split_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+ split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

@non_differentiable make_causal_mask(x)
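
For orientation, here is a minimal usage sketch of the renamed keyword interface documented in the diff above. It assumes this code lives in NNlib, that `dot_product_attention` keeps the `nheads`, `fdrop`, and `mask` keywords shown in the `+` lines, and that the output shapes follow the docstring conventions; the size comments are expectations taken from the docstring, not verified output.

```julia
using NNlib  # assumes dot_product_attention is provided by NNlib, as in this file

# q, k have size (qk_dim, len, batch...) and v has size (v_dim, kv_len, batch...)
q, k, v = rand(Float32, 10, 20, 2), rand(Float32, 10, 30, 2), rand(Float32, 20, 30, 2)

# Two heads: qk_dim and v_dim must both be divisible by nheads.
# fdrop defaults to identity, i.e. no dropout is applied to the attention scores.
y, α = dot_product_attention(q, k, v; nheads = 2)
# expected per the docstring: size(y) == (20, 20, 2) and size(α) == (30, 20, 2, 2)

# Self-attention with a causal mask (query and key lengths must match for :causal).
x = rand(Float32, 10, 20, 2)
y2, α2 = dot_product_attention(x, x, x; nheads = 2, mask = :causal)
```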
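The `permutedims`/`batched_mul` route for the scores can be cross-checked against a plain-loop transcription of the `@tullio` comment. The sketch below is illustrative only: `scores_loop` and `scores_bmm` are hypothetical helper names, and heads are folded into the batch dimension so only the 3-dimensional `batched_mul` method is needed, which differs slightly from the 4-dimensional call in the diff.

```julia
using NNlib: batched_mul

# Loop transcription of: logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √(head_dim),
# where head_dim = size(q, 1) is the per-head feature dimension.
function scores_loop(q, k)
    T = eltype(q)
    d, h, qlen, b = size(q)
    klen = size(k, 3)
    logits = zeros(T, klen, qlen, h, b)
    for bi in 1:b, hi in 1:h, i in 1:qlen, j in 1:klen, di in 1:d
        logits[j, i, hi, bi] += q[di, hi, i, bi] * k[di, hi, j, bi] / sqrt(T(d))
    end
    return logits
end

# Same computation via permutedims + batched_mul, folding (heads, batch) into one batch axis.
function scores_bmm(q::AbstractArray{T,4}, k::AbstractArray{T,4}) where T
    kt = permutedims(k, (3, 1, 2, 4))                         # (kv_len, head_dim, nheads, batch)
    qt = permutedims(q, (1, 3, 2, 4)) ./ sqrt(T(size(q, 1)))  # (head_dim, q_len, nheads, batch)
    kt3 = reshape(kt, size(kt, 1), size(kt, 2), :)
    qt3 = reshape(qt, size(qt, 1), size(qt, 2), :)
    logits = batched_mul(kt3, qt3)                            # (kv_len, q_len, nheads * batch)
    return reshape(logits, size(kt, 1), size(qt, 2), size(q, 2), size(q, 4))
end

q = rand(Float32, 4, 2, 5, 3)  # (qk_dim ÷ nheads, nheads, q_len, batch)
k = rand(Float32, 4, 2, 7, 3)  # (qk_dim ÷ nheads, nheads, kv_len, batch)
@assert scores_loop(q, k) ≈ scores_bmm(q, k)
```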