Commit 1203b21

implement dot_product_attention (#455)

* add dot_product_attention
* run tests
* docs
* address some review comments
* fix tests
* fix fdrop
* additional method
* bias is positional argument
* test bias
* fix tests on julia 1.6
* typos
* improve docs
* remove :causal
* Update src/attention.jl
* add function barrier
1 parent 2ef2daa commit 1203b21

File tree: 7 files changed, +249 / -3 lines

* docs/src/reference.md
* src/NNlib.jl
* src/attention.jl
* src/batched/batchedmul.jl
* src/gemm.jl
* test/attention.jl
* test/runtests.jl

docs/src/reference.md

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,14 @@ tanhshrink
 trelu
 ```
 
+## Attention
+
+```@docs
+dot_product_attention
+dot_product_attention_scores
+make_causal_mask
+```
+
 ## Softmax
 
 `Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.

src/NNlib.jl

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ for f in ACTIVATIONS
 end
 export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases
 
+include("attention.jl")
+export dot_product_attention, dot_product_attention_scores, make_causal_mask
+
 include("dropout.jl")
 export dropout, dropout!

src/attention.jl

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
  Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
            fdrop=identity, mask=nothing, nheads=1)

    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
    x = batched_mul(vt, α)
    x = permutedims(x, (1, 3, 2, 4))
    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
    return x, α
end

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
            fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
    kt = permutedims(k, (3, 1, 2, 4))
    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
    logits = batched_mul(kt, qt)
    # [logits] = [kv_len, q_len, nheads, batch_size]

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)

    α = softmax(logits, dims=1)
    return fdrop(α)
end

apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias

apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
    neginf = typemin(eltype(logits))
    ifelse.(mask, logits, neginf)
end

"""
    make_causal_mask(x, dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
    len = size(x, dims)
    mask = triu(trues_like(x, (len, len)))
    return mask
end

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)
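For orientation, here is a brief usage sketch of the API introduced in `src/attention.jl` above. The array sizes, variable names, and the self-attention setup are illustrative choices, not taken from the commit:

```julia
using NNlib

# Illustrative sizes: 16 features, sequence length 10, batch of 3, 4 heads.
emb, len, batch, nheads = 16, 10, 3, 4
x = rand(Float32, emb, len, batch)   # (features, seq_len, batch)

# Causal self-attention: each position attends only to itself and earlier positions.
mask = make_causal_mask(x)           # (len, len) boolean, true on and above the diagonal
y, α = dot_product_attention(x, x, x; nheads, mask)

size(y)                   # (16, 10, 3)    == (v_dim, q_len, batch_size)
size(α)                   # (10, 10, 4, 3) == (kv_len, q_len, nheads, batch_size)
all(sum(α, dims=1) .≈ 1)  # true: scores are normalized over the key dimension
```

A `bias` array and an `fdrop` dropout function slot into the same call, as described in the docstring above.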

src/batched/batchedmul.jl

Lines changed: 13 additions & 2 deletions
@@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A)
     batched_mul(A, B) -> C
     A ⊠ B # \\boxtimes
 
-Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
-If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
+Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...`
+represents any indices in the last dimensions.
+
+If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
 
 To transpose each matrix, apply `batched_transpose` to the array,
 or `batched_adjoint` for conjugate-transpose:

@@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
 Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
 and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
 """
+function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    batch_size = size(x)[3:end]
+    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
+    x2 = reshape(x, size(x, 1), size(x, 2), :)
+    y2 = reshape(y, size(y, 1), size(y, 2), :)
+    z = batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), batch_size...)
+end
+
 function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
     size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
         throw(DimensionMismatch("batch size mismatch: A != B"))
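The new `batched_mul` method above generalizes the 3-dimensional case by folding all trailing batch dimensions into one, multiplying, and reshaping back. A minimal sketch with illustrative sizes:

```julia
using NNlib, LinearAlgebra

A = rand(2, 3, 4, 5)   # (m, k, batch dims...)
B = rand(3, 7, 4, 5)   # (k, n, same batch dims)

C = batched_mul(A, B)  # equivalently A ⊠ B
size(C)                # (2, 7, 4, 5)

# Each slice is an ordinary matrix product:
C[:, :, 2, 3] ≈ A[:, :, 2, 3] * B[:, :, 2, 3]   # true
```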

src/gemm.jl

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings
 
             end
 
-            C
+            return C
         end
     end
 end

test/attention.jl

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
@testset "different batchsizes" begin
    n = 15
    lenq = 3
    lenkv = 4
    for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]
        q = rand(Float32, n, lenq, batch_size...)
        k = rand(Float32, n, lenkv, batch_size...)
        v = rand(Float32, n, lenkv, batch_size...)
        y, α = dot_product_attention(q, k, v; nheads)
        @test y isa Array{Float32}
        @test size(y) == (n, lenq, batch_size...)
        @test size(α) == (lenkv, lenq, nheads, batch_size...)
        @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
    end
end

@testset "dot_product_attention_scores" begin
    q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
    α = dot_product_attention_scores(q, k)
    q2, k2 = reshape.((q, k), 8, 3, 1)
    y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
    @test α ≈ α2
end

@testset "specific results" begin
    q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
    y, α = dot_product_attention(q, k, v; nheads=2)
    ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788]
    ytrue = reshape(ytrue, 4, 3, 1)
    αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801]
    αtrue = reshape(αtrue, 3, 3, 2, 1)
    @test y ≈ ytrue atol=1e-5
    @test α ≈ αtrue atol=1e-5
end

@testset "mask" begin
    q = rand(4, 2, 3, 1)
    k = rand(4, 2, 5, 1)

    mask = rand(Bool, (5, 3))
    α = dot_product_attention_scores(q, k; mask)
    @test all((α[:,:,1,1] .> 0) .== mask)
    @test all((α[:,:,2,1] .> 0) .== mask)

    @testset "causal" begin
        x = rand(4, 2, 3, 1)
        mask = make_causal_mask(x, dims=3)
        α = dot_product_attention_scores(x, x; mask)
        @test all((α[:,:,1,1] .> 0) .== mask)
        @test all((α[:,:,2,1] .> 0) .== mask)
    end
end

@testset "dropout" begin
    q = k = v = rand(10, 10, 10)
    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
    y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
    @test 0.6 > mean(>(0), α) > 0.4
end

@testset "bias" begin
    q = rand(4, 5, 1)
    k = v = rand(4, 3, 1)
    bias = randn(3, 5)
    y, α = dot_product_attention(q, k, v, bias; nheads=2)
    @test size(α) == (3, 5, 2, 1)
    @test size(y) == (4, 5, 1)
end

@testset "gradient" begin
    q = rand(4, 5, 1)
    k = v = rand(4, 3, 1)
    bias = randn(3, 5)
    y, α = dot_product_attention(q, k, v, bias; nheads=2)
    gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)
end
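Beyond the tests above, the single-head, single-batch case should reduce to the textbook formula `softmax(Kᵀ Q / √d)` applied to `V`. The following cross-check is a sketch, not part of the test suite, and the sizes are illustrative:

```julia
using NNlib, LinearAlgebra

d, lq, lkv = 8, 5, 7
q, k, v = rand(d, lq, 1), rand(d, lkv, 1), rand(d, lkv, 1)

y, α = dot_product_attention(q, k, v)          # nheads = 1 by default

logits = k[:, :, 1]' * q[:, :, 1] ./ sqrt(d)   # (lkv, lq) scaled dot products
α_naive = softmax(logits, dims=1)              # normalize over the key dimension
y_naive = v[:, :, 1] * α_naive                 # (d, lq)

α[:, :, 1, 1] ≈ α_naive   # true
y[:, :, 1] ≈ y_naive      # true
```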

test/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ include("test_utils.jl")
     include("activations.jl")
 end
 
+@testset "Attention" begin
+    include("attention.jl")
+end
+
 @testset "Batched Multiplication" begin
     include("batchedmul.jl")
 end
