+using Flux, Test, LinearAlgebra, Random, Statistics
+using CUDA, CUDAKernels, LoopVectorization
+using Tullio
+using NeuralAttentionlib
+using BenchmarkTools
+
+const A3{T} = AbstractArray{T, 3}
+
 """
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                attn_dropout_prob = 0., proj_dropout_prob = 0.)
+    MultiHeadAttention(dims, num_heads;
+                       [bias, init, attn_dropout_prob, out_proj_dropout_prob])

-Multi-head self-attention layer.
+Multi-head dot-product attention layer.

 # Arguments

-- `planes`: number of input channels
+- `dims`: the embedding dimensions. Either a single `Int` (used for the inputs, the
+  internal projection and the output), or `in_dim => (qkv_dim => out_dim)`, or
+  `(q_in_dim, k_in_dim, v_in_dim) => (qkv_dim => out_dim)`.
-- `nheads`: number of heads
+- `num_heads`: number of attention heads
-- `qkv_bias`: whether to use bias in the layer to get the query, key and value
+- `init`: weight initializer for the Dense layers.
+- `bias`: whether the pointwise query/key/value/output Dense layers use a bias.
 - `attn_dropout_prob`: dropout probability after the self-attention layer
-- `proj_dropout_prob`: dropout probability after the projection layer
+- `out_proj_dropout_prob`: dropout probability after the output projection layer
+
+# Forward
+
+- `q_in`: query tensor of shape `(q_in_dim, q_len, batch_size)`
+- `k_in`: key tensor of shape `(k_in_dim, kv_len, batch_size)`
+- `v_in`: value tensor of shape `(v_in_dim, kv_len, batch_size)`
+- `mask`: mask tensor of shape `(kv_len, q_len, batch_size)`
+- `with_weights`: whether to also return the attention weights
+- `impl`: which implementation to use, `:tullio` (default) or `:nnalib`
+
+# Examples
+
+```julia
+mha = MultiHeadAttention(64, 8)
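+
+# Inputs use the (features, seq_len, batch_size) layout assumed throughout this file.
+x = rand(Float32, 64, 100, 32)
+y = mha(x)            # self-attention; size(y) == (64, 100, 32)
+y = mha(x, x, x)      # equivalent explicit query/key/value call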
+```
 """
-struct MultiHeadAttention{P, Q, R}
-    nheads::Int
-    qkv_layer::P
-    attn_drop::Q
-    projection::R
-end
-
-@functor MHAttention
-
-function MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                            attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
-    @assert planes % nheads == 0 "planes should be divisible by nheads"
-    qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
-    attn_drop = Dropout(attn_dropout_prob)
-    proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
-    return MultiHeadAttention(nheads, qkv_layer, attn_drop, proj)
-end
-
-function (m::MultiHeadAttention)(x::AbstractArray{T, 3}) where {T}
-    nfeatures, seq_len, batch_size = size(x)
-    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
-    qkv = m.qkv_layer(x_reshaped)
-    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
-    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
-    scale = convert(T, sqrt(size(query, 1) / m.nheads))
-    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
-                           seq_len * batch_size)
-    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-
-    attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
-    attention = m.attn_drop(attention)
+struct MultiHeadAttention
+    num_heads::Int
+    qkv_proj
+    attn_drop
+    out_proj
+end
+
+@functor MultiHeadAttention
+
+function MultiHeadAttention(dims, num_heads::Int;
+                            bias::Bool = false,
+                            # init = glorot_uniform, # TODO
+                            attn_dropout_prob = 0.0,
+                            out_proj_dropout_prob = 0.0)
+
+    dims = mha_process_dims(dims)
+    @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
+    qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias)
+    attn_drop = Dropout(attn_dropout_prob)
+    out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob))
+    return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
+end
+
+mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out)
+mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
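+# Illustrative examples of the accepted `dims` forms (not part of the implementation):
+#   mha_process_dims(64)                        # (; q_in = 64, k_in = 64, v_in = 64, qkv = 64, out = 64)
+#   mha_process_dims(64 => (128 => 32))         # (; q_in = 64, k_in = 64, v_in = 64, qkv = 128, out = 32)
+#   mha_process_dims((3, 4, 5) => (128 => 32))  # (; q_in = 3, k_in = 4, v_in = 5, qkv = 128, out = 32)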
+
+# self-attention
+(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
+
+function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio)
+    ## [q_in] = [q_in_dim, q_len, batch_size]
+    ## [k_in] = [k_in_dim, kv_len, batch_size]
+    ## [v_in] = [v_in_dim, kv_len, batch_size]
+
+    if impl == :tullio
+        q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads)
+        # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size]
+        # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size]

-    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    pre_projection = reshape(batched_mul(attention, value_reshaped),
-                             (nfeatures, seq_len, batch_size))
-    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
-    return reshape(y, :, seq_len, batch_size)
+        x, α = dot_product_attention(q, k, v; dropout=m.attn_drop)
+        x = reshape(x, :, size(x, 3), size(x, 4))
+    elseif impl == :nnalib
+        q, k, v = m.qkv_proj(q_in, k_in, v_in)
+        x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v)
+    else
+        error("Unknown attention implementation: $impl")
+    end
+
+    x = m.out_proj(x)
+
+    return x
+    # return with_weights ? (x, α) : x
 end

-using Flux, Functors, Test, NNlib, MLUtils
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention
+function dot_product_attention(q, k, v; dropout=nothing)
+    α = dot_product_attention_weights(q, k; dropout)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
+    # [x] = [qkv_dim ÷ num_heads, num_heads, q_len, batch_size]
+
+    return x, α
+end
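+# In matrix form, per head h and batch element b, the two Tullio contractions compute
+#   x = V * softmax(Kᵀ Q / √d_head),
+# i.e. standard scaled dot-product attention (the 1/√d_head factor is applied in
+# dot_product_attention_weights below).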

-mha = MultiHeadAttention(64, 8)
-sz = (64, 100, 32)
-x = rand(Float32, sz)
-y = mha(x)
-@test y isa Array{Float32, 3}
-@test size(y) == sz
+function dot_product_attention_weights(q, k; dropout=nothing)
+    @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    # Scale by 1/√d_head before the softmax, as in standard scaled dot-product attention.
+    α = softmax(α ./ sqrt(eltype(α)(size(q, 1))), dims=1)
+    return dropout === nothing ? α : dropout(α)
+end
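+# Shape sanity check (illustrative, assuming qkv_dim = 64, num_heads = 8, q_len = kv_len = 100,
+# batch_size = 32): for q and k of size (8, 8, 100, 32), the weights α have size
+# (100, 100, 8, 32) and sum to 1 along dims = 1.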
+
+
+struct QKVProj
+    k_proj::Dense
+    v_proj::Dense
+    q_proj::Dense
+end
+
+@functor QKVProj
+
+function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
+    q_in_dim, k_in_dim, v_in_dim = in_dim
+    return QKVProj(
+        Dense(k_in_dim => qkv_dim; bias),
+        Dense(v_in_dim => qkv_dim; bias),
+        Dense(q_in_dim => qkv_dim; bias)
+    )
+end
+
+function (proj::QKVProj)(q_in, k_in, v_in, num_heads)
+    # Split the projected features into (head_dim, num_heads, seq_len, batch_size).
+    # k and v may have a different sequence length than q, so each array is reshaped
+    # using its own size rather than reusing q's.
+    split_heads(x) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+    q = split_heads(proj.q_proj(q_in))
+    k = split_heads(proj.k_proj(k_in))
+    v = split_heads(proj.v_proj(v_in))
+    return q, k, v
+end
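+# E.g. (illustrative): with qkv_dim = 64 and num_heads = 8, inputs of size (64, 100, 32)
+# are projected and reshaped to q, k, v of size (8, 8, 100, 32) each.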
+
+function (proj::QKVProj)(q_in, k_in, v_in)
+    return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
+end
+
+
+function perf(dim, len, batch_size, num_heads)
+    mha = MultiHeadAttention(dim, num_heads)
+    x = rand(Float32, (dim, len, batch_size))
+
+    y = mha(x, x, x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (dim, len, batch_size)
+
+
+    println("tullio")
+    @btime $mha($x, impl=:tullio);
+    @btime gradient(m -> sum(m($x, impl=:tullio)), $mha);
+
+    println("nnalib")
+    @btime $mha($x, $x, $x, impl=:nnalib);
+    @btime gradient(m -> sum(m($x, impl=:nnalib)), $mha);
+
+    if CUDA.functional()
+        mha_gpu = mha |> gpu
+        x_gpu = x |> gpu
+
+        println("tullio - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:tullio);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);
+
+        println("nnalib - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnalib);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnalib)), $mha_gpu);
+    end
+    return nothing
+end
+
+perf(64, 100, 32, 8)