From bf64ca8360f8c6c9ecd148ed4cf40704814f7e42 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 3 Jan 2023 10:47:46 +0100 Subject: [PATCH 01/15] add dot_product_attention --- src/NNlib.jl | 3 + src/attention.jl | 125 ++++++++++++++++++++++++++++++++++++++ src/batched/batchedmul.jl | 15 ++++- src/gemm.jl | 2 +- test/attention.jl | 39 ++++++++++++ 5 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 src/attention.jl create mode 100644 test/attention.jl diff --git a/src/NNlib.jl b/src/NNlib.jl index acca75299..8be0d01bc 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -41,6 +41,9 @@ for f in ACTIVATIONS end export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases +include("attention.jl") +export dot_product_attention, dot_product_attention_scores, make_causal_mask + include("dropout.jl") export dropout, dropout! diff --git a/src/attention.jl b/src/attention.jl new file mode 100644 index 000000000..895f58da8 --- /dev/null +++ b/src/attention.jl @@ -0,0 +1,125 @@ +const AA3{T} = AbstractArray{T,3} +const AA4{T} = AbstractArray{T,4} +const AA{N,T} = AbstractArray{T,N} + +""" + dot_product_attention(query, key, value; [bias, droput_fn, mask, num_heads]) + +Multihead dot product attention used in transformer architectures. + +The input arrays must have the first two dimensions given by the number of features +and the sequece length, then an arbitrary number of batch dimensions or none. + +# Arguments + +- `query`: Query array of size `(qk_dim, q_len, batch_size...)`. +- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. +- `value`: Value array of size `(v_dim, kv_len, batch_size...)`. +- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`. + Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. +- `dropout_fn`: A dropout function to apply on the attention scores. Default `nothing`. +- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`. + Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. +- `num_heads`: Number of heads to split the input arrays into. Default `1`. + +# Examples + +```julia +q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) +y, α = dot_product_attention(q, k, v) +``` +""" +function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...) where N + batch_size = size(q)[3:end] + + batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) + size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) + size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) + + q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) + + # Multihead attention. TODO create fastpath for singlehead attention. + q, k, v = split_heads.((q, k, v), num_heads) + x, α = _dot_product_attention(q, k, v; kws...) + x = join_heads(x) + + x = reshape(x, size(x, 1), size(x, 2), batch_size...) + α = reshape(α, size(α)[1:3]..., batch_size...) 
+ return x, α +end + +function _dot_product_attention(q::AA4, k::AA4, v::AA4; + dropout_fn=nothing, bias=nothing, mask=nothing) + + α = dot_product_attention_scores(q, k; dropout_fn, bias, mask) + # [α] = [kv_len, q_len, num_heads, batch_size] + + # The following permutedims and batched_mul are equivalent to + # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] + vt = permutedims(v, (1, 3, 2, 4)) + x = batched_mul(vt, α) + x = permutedims(x, (1, 3, 2, 4)) + # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] + return x, α +end + +""" + dot_product_attention_scores(query, key; [bias, droput_fn, mask]) + +Return the attention scores for the [`dot_product_attention`](@ref). + +Input arrays must have dimensions `(num_features ÷ num_heads, num_heads, sequence_length, batch_size)` + +""" +function dot_product_attention_scores(q::AA4{T}, k::AA4{T}; + dropout_fn=nothing, mask=nothing, bias=nothing) where T + + q = q ./ √T(size(q, 1)) + + # The following permutedims and batched_mul are equivalent to + # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] + kt = permutedims(k, (3, 1, 2, 4)) + qt = permutedims(q, (1, 3, 2, 4)) + α = batched_mul(kt, qt) + # [α] = [kv_len, q_len, num_heads, batch_size] + + if bias !== nothing + α = α .+ bias + end + + if mask !== nothing + if mask === :causal + mask = make_causal_mask(α) + end + neginf = typemin(eltype(α)) + α = ifelse.(mask, α, neginf) + end + + α = softmax(α, dims=1) + return dropout_fn === nothing ? α : dropout_fn(α) +end + +""" + make_causal_mask(x) + +Return a boolean square matrix `m` of the same type as `x` and of side `size(x,2)`. +Its elements are set such that `m[i, j] == i ≤ j`. + +Can be used to mask the attention scores in [`dot_product_attention`](@ref). +""" +function make_causal_mask(x::AbstractArray) + len = size(x, 2) + mask = triu(trues_like(x, (len, len))) + return mask +end + +trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) +falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) + +split_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) +join_heads(x) = reshape(x, :, size(x)[3:end]...) + +@non_differentiable make_causal_mask(x) +@non_differentiable trues_like(::Any...) +@non_differentiable falses_like(::Any...) + diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 7e5e7fd72..7458f6fa2 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A) batched_mul(A, B) -> C A ⊠ B # \\boxtimes -Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`. -If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. +Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent +any indices in the last dimensions. + +If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. To transpose each matrix, apply `batched_transpose` to the array, or `batched_adjoint` for conjugate-transpose: @@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`. Both this `copy` and `batched_mul_generic!` produce `@debug` messages, and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. 
""" +function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} + batch_size = size(x)[3:end] + @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays." + x2 = reshape(x, size(x, 1), size(x, 2), :) + y2 = reshape(y, size(y, 1), size(y, 2), :) + z = batched_mul(x2, y2) + return reshape(z, size(z, 1), size(z, 2), batch_size...) + end + function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != B")) diff --git a/src/gemm.jl b/src/gemm.jl index 91f88fc82..95c39d23f 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings end - C + return C end end end diff --git a/test/attention.jl b/test/attention.jl new file mode 100644 index 000000000..f14dfd5ea --- /dev/null +++ b/test/attention.jl @@ -0,0 +1,39 @@ +@testset "different batchsizes" begin + n = 15 + lenq = 3 + lenkv = 4 + for batch_size in [(), 1, 2, (2,1,3)], num_heads in [1, 3, 5] + q = rand(Float32, n, lenq, batch_size...) + k = rand(Float32, n, lenkv, batch_size...) + v = rand(Float32, n, lenkv, batch_size...) + y, α = dot_product_attention(q, k, v; num_heads) + @test y isa Array{Float32} + @test size(y) == (n, lenq, batch_size...) + @test size(α) == (lenkv, lenq, num_heads, batch_size...) + @test sum(α, dims=1) ≈ ones(1, lenq, num_heads, batch_size...) + end +end + +@testset "dot_product_attention_scores" begin + q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24 + α = dot_product_attention_scores(q, k) + q2, k2 = reshape.((q, k), 8, 3, 1) + y, α2 = dot_product_attention(q2, k2, k2; num_heads=2) + @test α ≈ α2 +end + +@testset "specific results" begin + q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 + y, α = dot_product_attention(q, k, v; num_heads=2) + @test y ≈ [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;] + @test α ≈ [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;] +end + +@testset "mask" begin + q = rand(4, 2, 3, 1) + k = rand(4, 2, 5, 1) + mask = rand(Bool, (5, 3)) + α = dot_product_attention_scores(q, k; mask) + @test all((α[:,:,1,1].> 0) .== mask) + @test all((α[:,:,2,1].> 0) .== mask) +end From 9da000556675b95105e5722a62144dd4905d4a8b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 3 Jan 2023 10:48:46 +0100 Subject: [PATCH 02/15] run tests --- test/runtests.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 16084b4d2..c6c8333d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,6 +39,10 @@ include("test_utils.jl") include("activations.jl") end + @testset "Attention" begin + include("activations.jl") + end + @testset "Batched Multiplication" begin include("batchedmul.jl") end From 2193639df79b8fa4a0cc8a097ee2f56cd482bb76 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 3 Jan 2023 10:50:33 +0100 Subject: [PATCH 03/15] docs --- docs/src/reference.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/src/reference.md 
b/docs/src/reference.md index a034c92c5..cb10efe59 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -33,6 +33,14 @@ tanhshrink trelu ``` +## Attention + +```@docs +dot_product_attention +dot_product_attention_scores +make_causal_mask +``` + ## Softmax `Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally. From eabcc02286d7561150a1458283a698aa5e5ea8ac Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 4 Jan 2023 10:48:15 +0100 Subject: [PATCH 04/15] address some review comments --- src/attention.jl | 51 +++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index 895f58da8..52917df88 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -3,7 +3,7 @@ const AA4{T} = AbstractArray{T,4} const AA{N,T} = AbstractArray{T,N} """ - dot_product_attention(query, key, value; [bias, droput_fn, mask, num_heads]) + dot_product_attention(query, key, value; [bias, fdrop, mask, nheads]) Multihead dot product attention used in transformer architectures. @@ -15,12 +15,11 @@ and the sequece length, then an arbitrary number of batch dimensions or none. - `query`: Query array of size `(qk_dim, q_len, batch_size...)`. - `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. - `value`: Value array of size `(v_dim, kv_len, batch_size...)`. -- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`. +- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. +- `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout). +- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. -- `dropout_fn`: A dropout function to apply on the attention scores. Default `nothing`. -- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`. - Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. -- `num_heads`: Number of heads to split the input arrays into. Default `1`. +- `nheads`: Number of heads to split the input arrays into. Default `1`. # Examples @@ -29,7 +28,7 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) y, α = dot_product_attention(q, k, v) ``` """ -function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...) where N +function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N batch_size = size(q)[3:end] batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) @@ -39,7 +38,7 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws... q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) # Multihead attention. TODO create fastpath for singlehead attention. - q, k, v = split_heads.((q, k, v), num_heads) + q, k, v = split_heads.((q, k, v), nheads) x, α = _dot_product_attention(q, k, v; kws...) x = join_heads(x) @@ -49,17 +48,17 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws... 
end function _dot_product_attention(q::AA4, k::AA4, v::AA4; - dropout_fn=nothing, bias=nothing, mask=nothing) + fdrop=nothing, bias=nothing, mask=nothing) - α = dot_product_attention_scores(q, k; dropout_fn, bias, mask) - # [α] = [kv_len, q_len, num_heads, batch_size] + α = dot_product_attention_scores(q, k; fdrop, bias, mask) + # [α] = [kv_len, q_len, nheads, batch_size] # The following permutedims and batched_mul are equivalent to # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] vt = permutedims(v, (1, 3, 2, 4)) x = batched_mul(vt, α) x = permutedims(x, (1, 3, 2, 4)) - # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] + # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] return x, α end @@ -68,35 +67,33 @@ end Return the attention scores for the [`dot_product_attention`](@ref). -Input arrays must have dimensions `(num_features ÷ num_heads, num_heads, sequence_length, batch_size)` +Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`. """ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}; - dropout_fn=nothing, mask=nothing, bias=nothing) where T + fdrop=identity, mask=nothing, bias=nothing) where T - q = q ./ √T(size(q, 1)) - # The following permutedims and batched_mul are equivalent to - # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] + # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim) kt = permutedims(k, (3, 1, 2, 4)) - qt = permutedims(q, (1, 3, 2, 4)) - α = batched_mul(kt, qt) - # [α] = [kv_len, q_len, num_heads, batch_size] + qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1)) + logits = batched_mul(kt, qt) + # [logits] = [kv_len, q_len, nheads, batch_size] if bias !== nothing - α = α .+ bias + logits = logits .+ bias end if mask !== nothing if mask === :causal - mask = make_causal_mask(α) + mask = make_causal_mask(logits) end - neginf = typemin(eltype(α)) - α = ifelse.(mask, α, neginf) + neginf = typemin(eltype(logits)) + logits = ifelse.(mask, logits, neginf) end - α = softmax(α, dims=1) - return dropout_fn === nothing ? α : dropout_fn(α) + α = softmax(logits, dims=1) + return fdrop(α) end """ @@ -116,7 +113,7 @@ end trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) -split_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) +split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) join_heads(x) = reshape(x, :, size(x)[3:end]...) @non_differentiable make_causal_mask(x) From aac281d31c0cd43fa5dfd7b9080c3ad06fd4a8da Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 4 Jan 2023 11:26:55 +0100 Subject: [PATCH 05/15] fix tests --- src/attention.jl | 8 ++++---- test/attention.jl | 31 +++++++++++++++++++++++++------ test/runtests.jl | 2 +- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index 52917df88..bb7b196e4 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -97,15 +97,15 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}; end """ - make_causal_mask(x) + make_causal_mask(x, dims=2) -Return a boolean square matrix `m` of the same type as `x` and of side `size(x,2)`. +Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`. Its elements are set such that `m[i, j] == i ≤ j`. Can be used to mask the attention scores in [`dot_product_attention`](@ref). 
""" -function make_causal_mask(x::AbstractArray) - len = size(x, 2) +function make_causal_mask(x::AbstractArray; dims::Int=2) + len = size(x, dims) mask = triu(trues_like(x, (len, len))) return mask end diff --git a/test/attention.jl b/test/attention.jl index f14dfd5ea..a6e57596e 100644 --- a/test/attention.jl +++ b/test/attention.jl @@ -2,15 +2,15 @@ n = 15 lenq = 3 lenkv = 4 - for batch_size in [(), 1, 2, (2,1,3)], num_heads in [1, 3, 5] + for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5] q = rand(Float32, n, lenq, batch_size...) k = rand(Float32, n, lenkv, batch_size...) v = rand(Float32, n, lenkv, batch_size...) - y, α = dot_product_attention(q, k, v; num_heads) + y, α = dot_product_attention(q, k, v; nheads) @test y isa Array{Float32} @test size(y) == (n, lenq, batch_size...) - @test size(α) == (lenkv, lenq, num_heads, batch_size...) - @test sum(α, dims=1) ≈ ones(1, lenq, num_heads, batch_size...) + @test size(α) == (lenkv, lenq, nheads, batch_size...) + @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...) end end @@ -18,13 +18,13 @@ end q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24 α = dot_product_attention_scores(q, k) q2, k2 = reshape.((q, k), 8, 3, 1) - y, α2 = dot_product_attention(q2, k2, k2; num_heads=2) + y, α2 = dot_product_attention(q2, k2, k2; nheads=2) @test α ≈ α2 end @testset "specific results" begin q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 - y, α = dot_product_attention(q, k, v; num_heads=2) + y, α = dot_product_attention(q, k, v; nheads=2) @test y ≈ [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;] @test α ≈ [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;] end @@ -32,8 +32,27 @@ end @testset "mask" begin q = rand(4, 2, 3, 1) k = rand(4, 2, 5, 1) + mask = rand(Bool, (5, 3)) α = dot_product_attention_scores(q, k; mask) @test all((α[:,:,1,1].> 0) .== mask) @test all((α[:,:,2,1].> 0) .== mask) + + @testset "causal" begin + x = rand(4, 2, 3, 1) + mask = make_causal_mask(x, dims=3) + α = dot_product_attention_scores(x, x; mask) + @test all((α[:,:,1,1].> 0) .== mask) + @test all((α[:,:,2,1].> 0) .== mask) + + α2 = dot_product_attention_scores(x, x; mask=:causal) + @test α2 ≈ α + end +end + +@testset "dropout" begin + q = k = v = rand(10, 10, 10) + fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p) + y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5)) + @test 0.6 > mean(>(0), α) > 0.4 end diff --git a/test/runtests.jl b/test/runtests.jl index c6c8333d2..e7987ef62 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,7 +40,7 @@ include("test_utils.jl") end @testset "Attention" begin - include("activations.jl") + include("attention.jl") end @testset "Batched Multiplication" begin From 4d5a6d90b89a0921595917eeeab0348fbf54031e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 4 Jan 2023 13:13:36 +0100 Subject: [PATCH 06/15] fix fdrop --- src/attention.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index bb7b196e4..c20d68002 100644 --- a/src/attention.jl +++ b/src/attention.jl 
@@ -48,7 +48,7 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) w end function _dot_product_attention(q::AA4, k::AA4, v::AA4; - fdrop=nothing, bias=nothing, mask=nothing) + fdrop=identity, bias=nothing, mask=nothing) α = dot_product_attention_scores(q, k; fdrop, bias, mask) # [α] = [kv_len, q_len, nheads, batch_size] @@ -68,7 +68,6 @@ end Return the attention scores for the [`dot_product_attention`](@ref). Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`. - """ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}; fdrop=identity, mask=nothing, bias=nothing) where T From 5a5c58beb5372ee232184c0a3fa8a4a258202c19 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 4 Jan 2023 16:58:27 +0100 Subject: [PATCH 07/15] additional method --- src/attention.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index c20d68002..921c54d66 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -37,19 +37,26 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) w q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) - # Multihead attention. TODO create fastpath for singlehead attention. - q, k, v = split_heads.((q, k, v), nheads) - x, α = _dot_product_attention(q, k, v; kws...) - x = join_heads(x) + x, α = dot_product_attention(q, k, v; nheads, kws...) x = reshape(x, size(x, 1), size(x, 2), batch_size...) α = reshape(α, size(α)[1:3]..., batch_size...) return x, α end -function _dot_product_attention(q::AA4, k::AA4, v::AA4; - fdrop=identity, bias=nothing, mask=nothing) +function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...) + # Multihead attention. TODO create fastpath for singlehead attention. + q, k, v = split_heads.((q, k, v), nheads) + x, α = _dot_product_attention(q, k, v; kws...) + return join_heads(x), α +end +function _dot_product_attention(q::AA4, k::AA4, v::AA4; + fdrop=identity, bias=nothing, mask=nothing) + # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size] + # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size] + # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size] + α = dot_product_attention_scores(q, k; fdrop, bias, mask) # [α] = [kv_len, q_len, nheads, batch_size] From 19d377a39db78db41031a3413aa6c572b2716f3a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 01:41:49 +0100 Subject: [PATCH 08/15] bias is positional argument --- src/attention.jl | 50 +++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index 921c54d66..ea9fe7e09 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -3,21 +3,28 @@ const AA4{T} = AbstractArray{T,4} const AA{N,T} = AbstractArray{T,N} """ - dot_product_attention(query, key, value; [bias, fdrop, mask, nheads]) + dot_product_attention(query, key, value [bias]; fdrop, mask, nheads]) Multihead dot product attention used in transformer architectures. The input arrays must have the first two dimensions given by the number of features -and the sequece length, then an arbitrary number of batch dimensions or none. +and the sequece length, then an arbitrary number of batch dimensions or none. + +Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores. +of size `(kv_len, q_len, nheads, batch_size...)`. + +See also [`dot_product_attention_scores`](@ref) if you only need the attention scores. 
# Arguments - `query`: Query array of size `(qk_dim, q_len, batch_size...)`. - `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. - `value`: Value array of size `(v_dim, kv_len, batch_size...)`. -- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. +- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. + It will be added to the attention scores before applying the softmax. Default `nothing`. - `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout). -- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. +- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. + The mask be applied to the attention scores before applying the softmax. Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. - `nheads`: Number of heads to split the input arrays into. Default `1`. @@ -28,36 +35,37 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) y, α = dot_product_attention(q, k, v) ``` """ -function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N +function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N batch_size = size(q)[3:end] - batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) - size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) - size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) - q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) - x, α = dot_product_attention(q, k, v; nheads, kws...) + x, α = dot_product_attention(q, k, v, args...; kws...) x = reshape(x, size(x, 1), size(x, 2), batch_size...) α = reshape(α, size(α)[1:3]..., batch_size...) return x, α end -function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...) +function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing; + fdrop=identity, mask=nothing, nheads=1) + + (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same.")) + size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) + size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) + # Multihead attention. TODO create fastpath for singlehead attention. q, k, v = split_heads.((q, k, v), nheads) - x, α = _dot_product_attention(q, k, v; kws...) 
+ x, α = _dot_product_attention(q, k, v, bias, fdrop, mask) return join_heads(x), α end -function _dot_product_attention(q::AA4, k::AA4, v::AA4; - fdrop=identity, bias=nothing, mask=nothing) +function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask) # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size] # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size] # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size] - - α = dot_product_attention_scores(q, k; fdrop, bias, mask) + + α = dot_product_attention_scores(q, k, bias; fdrop, mask) # [α] = [kv_len, q_len, nheads, batch_size] # The following permutedims and batched_mul are equivalent to @@ -70,14 +78,16 @@ function _dot_product_attention(q::AA4, k::AA4, v::AA4; end """ - dot_product_attention_scores(query, key; [bias, droput_fn, mask]) + dot_product_attention_scores(query, key, [bias]; [fdrop, mask]) Return the attention scores for the [`dot_product_attention`](@ref). +Input arrays must have dimensions +`(num_features ÷ nheads, nheads, sequence_length, batch_size)`. -Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`. +See [`dot_product_attention`](@ref) for more details. """ -function dot_product_attention_scores(q::AA4{T}, k::AA4{T}; - fdrop=identity, mask=nothing, bias=nothing) where T +function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; + fdrop=identity, mask=nothing) where T # The following permutedims and batched_mul are equivalent to # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim) From 10e99c753805571ac5e563dacddddc922fdae0b6 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 01:46:09 +0100 Subject: [PATCH 09/15] test bias --- test/attention.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/attention.jl b/test/attention.jl index a6e57596e..7c88a974c 100644 --- a/test/attention.jl +++ b/test/attention.jl @@ -56,3 +56,12 @@ end y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5)) @test 0.6 > mean(>(0), α) > 0.4 end + +@testset "bias" begin + q = rand(4, 5, 1) + k = v = rand(4, 3, 1) + bias = randn(3, 5) + y, α = dot_product_attention(q, k, v, bias; nheads=2) + @test size(α) == (3, 5, 2, 1) + @test size(y) == (4, 5, 1) +end From e61909cd3f29b6457d202d9ad19dfc5695626f10 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 07:33:25 +0100 Subject: [PATCH 10/15] fix tests on julia 1.6 --- test/attention.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/attention.jl b/test/attention.jl index 7c88a974c..b505ae7a3 100644 --- a/test/attention.jl +++ b/test/attention.jl @@ -25,8 +25,12 @@ end @testset "specific results" begin q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 y, α = dot_product_attention(q, k, v; nheads=2) - @test y ≈ [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;] - @test α ≈ [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;] + ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 
0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468] + ytrue = reshape(ytrue, 4, 3, 1) + αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966] + αtrue = reshape(αtrue, 3, 3, 2, 1) + @test y ≈ ytrue + @test α ≈ αtrue end @testset "mask" begin From 43632eee806bf0a638f909a3ce350fe4750ce0ef Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 07:46:34 +0100 Subject: [PATCH 11/15] typos --- src/attention.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index ea9fe7e09..19381953b 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -3,7 +3,7 @@ const AA4{T} = AbstractArray{T,4} const AA{N,T} = AbstractArray{T,N} """ - dot_product_attention(query, key, value [bias]; fdrop, mask, nheads]) + dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads]) Multihead dot product attention used in transformer architectures. @@ -24,7 +24,7 @@ See also [`dot_product_attention_scores`](@ref) if you only need the attention s It will be added to the attention scores before applying the softmax. Default `nothing`. - `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout). - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. - The mask be applied to the attention scores before applying the softmax. + The mask is applied to the attention scores before the softmax. Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. - `nheads`: Number of heads to split the input arrays into. Default `1`. From 958171b05ad08c636ee18ec5fcb9ae9c73810788 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 07:48:09 +0100 Subject: [PATCH 12/15] improve docs --- src/attention.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/attention.jl b/src/attention.jl index 19381953b..0f8adf9ab 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -22,7 +22,8 @@ See also [`dot_product_attention_scores`](@ref) if you only need the attention s - `value`: Value array of size `(v_dim, kv_len, batch_size...)`. - `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. It will be added to the attention scores before applying the softmax. Default `nothing`. -- `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout). +- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax. + Default `identity` (no dropout). - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. The mask is applied to the attention scores before the softmax. Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. 
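
[Illustrative aside, not part of the patch series: the `permutedims` + `batched_mul` pipeline in `dot_product_attention_scores` above is documented as equivalent to the commented `@tullio` einsum. A minimal sanity-check sketch of that claim, assuming NNlib with these patches is loaded; all shapes and values here are made up for the example.]

```julia
using NNlib

# Per-head layout used by the patches: (qk_dim ÷ nheads, nheads, len, batch),
# i.e. what `split_heads` produces from the (features, len, batch...) inputs.
d, nheads, q_len, kv_len, batch = 4, 2, 3, 5, 1
q = rand(Float32, d, nheads, q_len, batch)
k = rand(Float32, d, nheads, kv_len, batch)

α = dot_product_attention_scores(q, k)

# Naive reference for the commented einsum:
# logits[j, i, h, b] = Σ_d q[d, h, i, b] * k[d, h, j, b] / √qk_dim
logits = zeros(Float32, kv_len, q_len, nheads, batch)
for b in 1:batch, h in 1:nheads, i in 1:q_len, j in 1:kv_len
    logits[j, i, h, b] = sum(q[:, h, i, b] .* k[:, h, j, b]) / sqrt(Float32(d))
end
@assert α ≈ softmax(logits, dims=1)
```
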
From df8aa9bc820a39b4aa91027863436d6422e79681 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 23 Jan 2023 00:35:57 +0100 Subject: [PATCH 13/15] remove :causal --- src/attention.jl | 12 +++++------- test/attention.jl | 11 ++++++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index 0f8adf9ab..8f98bd096 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -8,9 +8,10 @@ const AA{N,T} = AbstractArray{T,N} Multihead dot product attention used in transformer architectures. The input arrays must have the first two dimensions given by the number of features -and the sequece length, then an arbitrary number of batch dimensions or none. +and the sequence length, then an arbitrary number of batch dimensions or none. -Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores. + +Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores of size `(kv_len, q_len, nheads, batch_size...)`. See also [`dot_product_attention_scores`](@ref) if you only need the attention scores. @@ -25,8 +26,8 @@ See also [`dot_product_attention_scores`](@ref) if you only need the attention s - `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax. Default `identity` (no dropout). - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. - The mask is applied to the attention scores before the softmax. - Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. + The mask is applied to the attention scores just before the softmax. + See [`make_causal_mask`](@ref) fore creating causal masks. Default `nothing`. - `nheads`: Number of heads to split the input arrays into. Default `1`. # Examples @@ -102,9 +103,6 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; end if mask !== nothing - if mask === :causal - mask = make_causal_mask(logits) - end neginf = typemin(eltype(logits)) logits = ifelse.(mask, logits, neginf) end diff --git a/test/attention.jl b/test/attention.jl index b505ae7a3..3f7725dd0 100644 --- a/test/attention.jl +++ b/test/attention.jl @@ -48,9 +48,6 @@ end α = dot_product_attention_scores(x, x; mask) @test all((α[:,:,1,1].> 0) .== mask) @test all((α[:,:,2,1].> 0) .== mask) - - α2 = dot_product_attention_scores(x, x; mask=:causal) - @test α2 ≈ α end end @@ -69,3 +66,11 @@ end @test size(α) == (3, 5, 2, 1) @test size(y) == (4, 5, 1) end + +@testset "gradient" begin + q = rand(4, 5, 1) + k = v = rand(4, 3, 1) + bias = randn(3, 5) + y, α = dot_product_attention(q, k, v, bias; nheads=2) + gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias) +end From 09ac33b0d31e55b0d02c2a52a8e60c4788b8e3da Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 24 Jan 2023 04:45:52 +0100 Subject: [PATCH 14/15] Update src/attention.jl --- src/attention.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/attention.jl b/src/attention.jl index 8f98bd096..9ec07ea9b 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -131,7 +131,7 @@ falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) join_heads(x) = reshape(x, :, size(x)[3:end]...) -@non_differentiable make_causal_mask(x) +@non_differentiable make_causal_mask(::Any...) @non_differentiable trues_like(::Any...) 
@non_differentiable falses_like(::Any...) From d17de5e69197d5023a0eb544a16a84f7b6d81dbf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 3 Feb 2023 07:59:10 +0100 Subject: [PATCH 15/15] add function barrier --- src/attention.jl | 25 ++++++++++++++++--------- test/attention.jl | 8 ++++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/attention.jl b/src/attention.jl index 9ec07ea9b..fb11e82d0 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -98,19 +98,26 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; logits = batched_mul(kt, qt) # [logits] = [kv_len, q_len, nheads, batch_size] - if bias !== nothing - logits = logits .+ bias - end - - if mask !== nothing - neginf = typemin(eltype(logits)) - logits = ifelse.(mask, logits, neginf) - end - + logits = apply_attn_bias(logits, bias) + logits = apply_attn_mask(logits, mask) + α = softmax(logits, dims=1) return fdrop(α) end +apply_attn_bias(logits, bias::Nothing) = logits + +apply_attn_bias(logits, bias) = logits .+ bias + + +apply_attn_mask(logits, mask::Nothing) = logits + +function apply_attn_mask(logits, mask) + neginf = typemin(eltype(logits)) + ifelse.(mask, logits, neginf) +end + + """ make_causal_mask(x, dims=2) diff --git a/test/attention.jl b/test/attention.jl index 3f7725dd0..b21088330 100644 --- a/test/attention.jl +++ b/test/attention.jl @@ -25,12 +25,12 @@ end @testset "specific results" begin q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 y, α = dot_product_attention(q, k, v; nheads=2) - ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468] + ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788] ytrue = reshape(ytrue, 4, 3, 1) - αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966] + αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801] αtrue = reshape(αtrue, 3, 3, 2, 1) - @test y ≈ ytrue - @test α ≈ αtrue + @test y ≈ ytrue atol=1e-5 + @test α ≈ αtrue atol=1e-5 end @testset "mask" begin
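
[To close, a sketch of how the API looks after the last patch — illustrative only, not part of the series; it assumes the `dot_product_attention` and `make_causal_mask` exports added above, and the shapes are made up for the example.]

```julia
using NNlib

x = rand(Float32, 8, 6, 2)      # (embed_dim, seq_len, batch)
mask = make_causal_mask(x)      # 6×6 Bool matrix with mask[i, j] == (i ≤ j)

# Two-head causal self-attention.
y, α = dot_product_attention(x, x, x; nheads=2, mask)

@assert size(y) == (8, 6, 2)        # (v_dim, q_len, batch)
@assert size(α) == (6, 6, 2, 2)     # (kv_len, q_len, nheads, batch)
@assert all(sum(α, dims=1) .≈ 1)    # scores are normalized over the key axis
@assert all(iszero, α[4, 2, :, :])  # future key j = 4 gets zero weight for query i = 2

# An additive bias (e.g. relative-position logits) broadcasts against
# (kv_len, q_len, nheads, batch) and is passed positionally; zeros leave y unchanged.
bias = zeros(Float32, 6, 6)
y2, _ = dot_product_attention(x, x, x, bias; nheads=2, mask)
@assert y2 ≈ y
```

Because `mask` and `bias` only need to broadcast to `(kv_len, q_len, nheads, batch_size)`, a single 2-d causal mask or bias matrix covers every head and batch element.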