- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 129
implement dot_product_attention #455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
bf64ca8
              9da0005
              2193639
              eabcc02
              aac281d
              4d5a6d9
              5a5c58b
              19d377a
              10e99c7
              e61909c
              43632ee
              958171b
              df8aa9b
              09ac33b
              d17de5e
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| const AA3{T} = AbstractArray{T,3} | ||
| const AA4{T} = AbstractArray{T,4} | ||
| const AA{N,T} = AbstractArray{T,N} | ||
|  | ||
| """ | ||
| dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads]) | ||
| Multihead dot product attention used in transformer architectures. | ||
| The input arrays must have the first two dimensions given by the number of features | ||
| and the sequence length, then an arbitrary number of batch dimensions or none. | ||
| Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores | ||
| of size `(kv_len, q_len, nheads, batch_size...)`. | ||
| See also [`dot_product_attention_scores`](@ref) if you only need the attention scores. | ||
| # Arguments | ||
| - `query`: Query array of size `(qk_dim, q_len, batch_size...)`. | ||
| - `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. | ||
| - `value`: Value array of size `(v_dim, kv_len, batch_size...)`. | ||
| - `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. | ||
| It will be added to the attention scores before applying the softmax. Default `nothing`. | ||
| - `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax. | ||
| Default `identity` (no dropout). | ||
| - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. | ||
| The mask is applied to the attention scores just before the softmax. | ||
| See [`make_causal_mask`](@ref) fore creating causal masks. Default `nothing`. | ||
| - `nheads`: Number of heads to split the input arrays into. Default `1`. | ||
| # Examples | ||
| ```julia | ||
| q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) | ||
| y, α = dot_product_attention(q, k, v) | ||
| ``` | ||
| """ | ||
| function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N | ||
| batch_size = size(q)[3:end] | ||
| batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) | ||
| q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) | ||
|  | ||
| x, α = dot_product_attention(q, k, v, args...; kws...) | ||
|  | ||
| x = reshape(x, size(x, 1), size(x, 2), batch_size...) | ||
| α = reshape(α, size(α)[1:3]..., batch_size...) | ||
| return x, α | ||
| end | ||
|  | ||
| function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing; | ||
| fdrop=identity, mask=nothing, nheads=1) | ||
|  | ||
| (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same.")) | ||
| size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) | ||
| size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) | ||
|  | ||
| # Multihead attention. TODO create fastpath for singlehead attention. | ||
| q, k, v = split_heads.((q, k, v), nheads) | ||
| x, α = _dot_product_attention(q, k, v, bias, fdrop, mask) | ||
| return join_heads(x), α | ||
| end | ||
|  | ||
| function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask) | ||
| # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size] | ||
| # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size] | ||
| # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size] | ||
|  | ||
| α = dot_product_attention_scores(q, k, bias; fdrop, mask) | ||
| # [α] = [kv_len, q_len, nheads, batch_size] | ||
|  | ||
| # The following permutedims and batched_mul are equivalent to | ||
| # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] | ||
| vt = permutedims(v, (1, 3, 2, 4)) | ||
| x = batched_mul(vt, α) | ||
| x = permutedims(x, (1, 3, 2, 4)) | ||
| # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] | ||
| return x, α | ||
| end | ||
|  | ||
| """ | ||
| dot_product_attention_scores(query, key, [bias]; [fdrop, mask]) | ||
| Return the attention scores for the [`dot_product_attention`](@ref). | ||
| Input arrays must have dimensions | ||
| `(num_features ÷ nheads, nheads, sequence_length, batch_size)`. | ||
| See [`dot_product_attention`](@ref) for more details. | ||
| """ | ||
| function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; | ||
| fdrop=identity, mask=nothing) where T | ||
|  | ||
| # The following permutedims and batched_mul are equivalent to | ||
| # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim) | ||
| kt = permutedims(k, (3, 1, 2, 4)) | ||
| qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1)) | ||
| logits = batched_mul(kt, qt) | ||
| # [logits] = [kv_len, q_len, nheads, batch_size] | ||
|  | ||
| if bias !== nothing | ||
| logits = logits .+ bias | ||
| end | ||
|  | ||
| if mask !== nothing | ||
| neginf = typemin(eltype(logits)) | ||
| logits = ifelse.(mask, logits, neginf) | ||
| end | ||
|          | ||
|  | ||
| α = softmax(logits, dims=1) | ||
| return fdrop(α) | ||
| end | ||
|  | ||
| """ | ||
| make_causal_mask(x, dims=2) | ||
| Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`. | ||
| Its elements are set such that `m[i, j] == i ≤ j`. | ||
| Can be used to mask the attention scores in [`dot_product_attention`](@ref). | ||
| """ | ||
| function make_causal_mask(x::AbstractArray; dims::Int=2) | ||
| len = size(x, dims) | ||
| mask = triu(trues_like(x, (len, len))) | ||
| return mask | ||
| end | ||
|  | ||
| trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) | ||
| falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) | ||
|  | ||
| split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) | ||
| join_heads(x) = reshape(x, :, size(x)[3:end]...) | ||
|  | ||
| @non_differentiable make_causal_mask(x) | ||
|         
                  CarloLucibello marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| @non_differentiable trues_like(::Any...) | ||
| @non_differentiable falses_like(::Any...) | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A) | |
| batched_mul(A, B) -> C | ||
| A ⊠ B # \\boxtimes | ||
|  | ||
| Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`. | ||
| If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. | ||
| Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent | ||
| any indices in the last dimensions. | ||
|  | ||
| If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. | ||
|  | ||
| To transpose each matrix, apply `batched_transpose` to the array, | ||
| or `batched_adjoint` for conjugate-transpose: | ||
|  | @@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`. | |
| Both this `copy` and `batched_mul_generic!` produce `@debug` messages, | ||
| and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. | ||
| """ | ||
| function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My vote is to make this an internal  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a pity to not make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method? | ||
| batch_size = size(x)[3:end] | ||
| @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays." | ||
| x2 = reshape(x, size(x, 1), size(x, 2), :) | ||
| y2 = reshape(y, size(y, 1), size(y, 2), :) | ||
| z = batched_mul(x2, y2) | ||
| return reshape(z, size(z, 1), size(z, 2), batch_size...) | ||
| end | ||
|  | ||
| function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} | ||
| size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 || | ||
| throw(DimensionMismatch("batch size mismatch: A != B")) | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings | |
|  | ||
| end | ||
|  | ||
| C | ||
| return C | ||
| end | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| @testset "different batchsizes" begin | ||
| n = 15 | ||
| lenq = 3 | ||
| lenkv = 4 | ||
| for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5] | ||
| q = rand(Float32, n, lenq, batch_size...) | ||
| k = rand(Float32, n, lenkv, batch_size...) | ||
| v = rand(Float32, n, lenkv, batch_size...) | ||
| y, α = dot_product_attention(q, k, v; nheads) | ||
| @test y isa Array{Float32} | ||
| @test size(y) == (n, lenq, batch_size...) | ||
| @test size(α) == (lenkv, lenq, nheads, batch_size...) | ||
| @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...) | ||
| end | ||
| end | ||
|  | ||
| @testset "dot_product_attention_scores" begin | ||
| q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24 | ||
| α = dot_product_attention_scores(q, k) | ||
| q2, k2 = reshape.((q, k), 8, 3, 1) | ||
| y, α2 = dot_product_attention(q2, k2, k2; nheads=2) | ||
| @test α ≈ α2 | ||
| end | ||
|  | ||
| @testset "specific results" begin | ||
| q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 | ||
| y, α = dot_product_attention(q, k, v; nheads=2) | ||
| ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468] | ||
| ytrue = reshape(ytrue, 4, 3, 1) | ||
| αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966] | ||
|         
                  CarloLucibello marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| αtrue = reshape(αtrue, 3, 3, 2, 1) | ||
| @test y ≈ ytrue | ||
| @test α ≈ αtrue | ||
| end | ||
|  | ||
| @testset "mask" begin | ||
| q = rand(4, 2, 3, 1) | ||
| k = rand(4, 2, 5, 1) | ||
|  | ||
| mask = rand(Bool, (5, 3)) | ||
| α = dot_product_attention_scores(q, k; mask) | ||
| @test all((α[:,:,1,1].> 0) .== mask) | ||
| @test all((α[:,:,2,1].> 0) .== mask) | ||
|  | ||
| @testset "causal" begin | ||
| x = rand(4, 2, 3, 1) | ||
| mask = make_causal_mask(x, dims=3) | ||
| α = dot_product_attention_scores(x, x; mask) | ||
| @test all((α[:,:,1,1].> 0) .== mask) | ||
| @test all((α[:,:,2,1].> 0) .== mask) | ||
| end | ||
| end | ||
|  | ||
| @testset "dropout" begin | ||
| q = k = v = rand(10, 10, 10) | ||
| fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p) | ||
| y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5)) | ||
| @test 0.6 > mean(>(0), α) > 0.4 | ||
| end | ||
|  | ||
| @testset "bias" begin | ||
| q = rand(4, 5, 1) | ||
| k = v = rand(4, 3, 1) | ||
| bias = randn(3, 5) | ||
| y, α = dot_product_attention(q, k, v, bias; nheads=2) | ||
| @test size(α) == (3, 5, 2, 1) | ||
| @test size(y) == (4, 5, 1) | ||
| end | ||
|  | ||
| @testset "gradient" begin | ||
| q = rand(4, 5, 1) | ||
| k = v = rand(4, 3, 1) | ||
| bias = randn(3, 5) | ||
| y, α = dot_product_attention(q, k, v, bias; nheads=2) | ||
| gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias) | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.