const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
  Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
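
A rough sketch of the multihead, masked case (the sizes below are arbitrary; any
`qk_dim` divisible by `nheads` works the same way):

```julia
q = k = v = rand(Float32, 8, 10, 2)    # (features, seq_len, batch_size)
mask = make_causal_mask(q)             # 10×10 causal mask, broadcast over heads and batches
y, α = dot_product_attention(q, k, v; nheads=4, mask)
size(y) == (8, 10, 2)                  # (v_dim, q_len, batch_size)
size(α) == (10, 10, 4, 2)              # (kv_len, q_len, nheads, batch_size)
```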
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
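    # Collapse all batch dimensions into a single one, call the 3-dimensional method,
    # then restore the original batch dimensions on the way out.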
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
        fdrop=identity, mask=nothing, nheads=1)

    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
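    # [vt] = [v_dim ÷ nheads, kv_len, nheads, batch_size]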
    x = batched_mul(vt, α)
    x = permutedims(x, (1, 3, 2, 4))
    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
    return x, α
end

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
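
A minimal usage sketch, with the head dimension already split out (the sizes below are arbitrary):

```julia
q = rand(Float32, 4, 2, 5, 3)   # (qk_dim ÷ nheads, nheads, q_len, batch_size)
k = rand(Float32, 4, 2, 6, 3)   # (qk_dim ÷ nheads, nheads, kv_len, batch_size)
α = dot_product_attention_scores(q, k)
size(α) == (6, 5, 2, 3)         # (kv_len, q_len, nheads, batch_size)
```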
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
        fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
    kt = permutedims(k, (3, 1, 2, 4))
    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
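    # [kt] = [kv_len, qk_dim ÷ nheads, nheads, batch_size]
    # [qt] = [qk_dim ÷ nheads, q_len, nheads, batch_size]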
    logits = batched_mul(kt, qt)
    # [logits] = [kv_len, q_len, nheads, batch_size]

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)

    α = softmax(logits, dims=1)
    return fdrop(α)
end

apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias


apply_attn_mask(logits, mask::Nothing) = logits

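# Replace the logits of masked-out positions (where `mask` is false) with the most negative
# representable value, so that they get (essentially) zero weight after the softmax.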
function apply_attn_mask(logits, mask)
    neginf = typemin(eltype(logits))
    ifelse.(mask, logits, neginf)
end


"""
    make_causal_mask(x; dims=2)

Return a boolean square matrix `m` of the same array type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
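
For example (a small sketch):

```julia
m = make_causal_mask(ones(Float32, 3, 5))   # square side is size(x, 2) == 5
# m is a 5×5 Matrix{Bool} with m[i, j] == (i <= j):
# 1 1 1 1 1
# 0 1 1 1 1
# 0 0 1 1 1
# 0 0 0 1 1
# 0 0 0 0 1
```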
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
    len = size(x, dims)
    mask = triu(trues_like(x, (len, len)))
    return mask
end

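# Bool arrays of size `sz`, allocated with `similar` so they match the array type (and device) of `x`.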
trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

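# Reshape (features, len, batch...) to (features ÷ nheads, nheads, len, batch...) and back.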
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

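# These helpers construct constant masks, so mark them non-differentiable for ChainRules-based AD.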
@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)