Commit 8ac9e6b

generic attention
1 parent 0ec9a00 commit 8ac9e6b

2 files changed: +166 -50 lines changed
Project.toml

Lines changed: 4 additions & 0 deletions
@@ -5,13 +5,16 @@ version = "0.13.10"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
+NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
@@ -21,6 +24,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
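
The four new dependencies back the attention rewrite in src/layers/attention.jl below: Tullio supplies the einsum-style kernels, LoopVectorization and CUDAKernels let Tullio generate fast CPU and GPU code respectively, and NeuralAttentionlib is pulled in as an alternative backend to benchmark against. As a quick illustration (not part of the commit; the array sizes here are made up), this is the kind of `@tullio` contraction used later for the attention scores:

```julia
using Tullio, LoopVectorization   # LoopVectorization speeds up Tullio's CPU loops
# using CUDA, CUDAKernels         # loading these lets Tullio emit GPU kernels instead

head_dim, num_heads, q_len, kv_len, batch = 8, 4, 10, 12, 2
q = rand(Float32, head_dim, num_heads, q_len, batch)
k = rand(Float32, head_dim, num_heads, kv_len, batch)

# One score per (key position, query position, head, batch), summing over features d.
@tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
size(α)  # (kv_len, q_len, num_heads, batch) = (12, 10, 4, 2)
```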

src/layers/attention.jl

Lines changed: 162 additions & 50 deletions
@@ -1,63 +1,175 @@
+using Flux, Test, LinearAlgebra, Random, Statistics
+using CUDA, CUDAKernels, LoopVectorization
+using Tullio
+using NeuralAttentionlib
+using BenchmarkTools
+
+const A3{T} = AbstractArray{T, 3}
+
 """
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                attn_dropout_prob = 0., proj_dropout_prob = 0.)
+    MultiHeadAttention(dims, num_heads;
+                       [bias, init, attn_dropout_prob, proj_dropout_prob])
 
-Multi-head self-attention layer.
+Multi-head dot-product attention layer.
 
 # Arguments
 
-  - `planes`: number of input channels
+  - `dims`: ...
   - `nheads`: number of heads
-  - `qkv_bias`: whether to use bias in the layer to get the query, key and value
+  - `init`: weight initializer for the Dense layers.
+  - `bias`: whether pointwise QKVO dense transforms use bias.
   - `attn_dropout_prob`: dropout probability after the self-attention layer
   - `proj_dropout_prob`: dropout probability after the projection layer
+
+# Forward
+
+  - `in_q`: input tensor of shape `(batch_size, seq_len, dims)`
+  - `in_k`: input tensor of shape `(batch_size, seq_len, dims)`
+  - `in_v`: input tensor of shape `(batch_size, seq_len, dims)`
+  - `mask`: input tensor of shape `(batch_size, seq_len, seq_len)`
+  - `return_weights`: whether to return the attention weights
+
+# Examples
+
+```julia
+mha = MultiHeadAttention(64, 8)
+```
 """
-struct MultiHeadAttention{P, Q, R}
-    nheads::Int
-    qkv_layer::P
-    attn_drop::Q
-    projection::R
-end
-
-@functor MHAttention
-
-function MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                            attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
-    @assert planes % nheads==0 "planes should be divisible by nheads"
-    qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
-    attn_drop = Dropout(attn_dropout_prob)
-    proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
-    return MultiHeadAttention(nheads, qkv_layer, attn_drop, proj)
-end
-
-function (m::MultiHeadAttention)(x::AbstractArray{T, 3}) where {T}
-    nfeatures, seq_len, batch_size = size(x)
-    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
-    qkv = m.qkv_layer(x_reshaped)
-    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
-    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
-    scale = convert(T, sqrt(size(query, 1) / m.nheads))
-    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
-                           seq_len * batch_size)
-    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-
-    attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
-    attention = m.attn_drop(attention)
+struct MultiHeadAttention
+    num_heads::Int
+    qkv_proj
+    attn_drop
+    out_proj
+end
+
+@functor MultiHeadAttention
+
+function MultiHeadAttention(dims, num_heads::Int;
+                            bias::Bool = false,
+                            # init = glorot_uniform, # TODO
+                            attn_dropout_prob = 0.0,
+                            out_proj_dropout_prob = 0.0)
+
+    dims = mha_process_dims(dims)
+    @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
+    qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias)
+    attn_drop = Dropout(attn_dropout_prob)
+    out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob))
+    return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
+end
+
+mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out)
+mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
+
+# self-attention
+(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
+
+function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, v=:tullio)
+    ## [q_in] = [q_in_dim, q_len, batch_size]
+    ## [k_in] = [k_in_dim, kv_len, batch_size]
+    ## [v_in] = [v_in_dim, kv_len, batch_size]
+
+    if v == :tullio
+        q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads)
+        # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size]
+        # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size]
 
-    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    pre_projection = reshape(batched_mul(attention, value_reshaped),
-                             (nfeatures, seq_len, batch_size))
-    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
-    return reshape(y, :, seq_len, batch_size)
+        x, α = dot_product_attention(q, k, v; dropout=m.attn_drop)
+        x = reshape(x, :, size(x, 3), size(x, 4))
+    elseif v == :nnalib
+        q, k, v = m.qkv_proj(q_in, k_in, v_in)
+        x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v)
+    else
+        error("Unknown attention implementation")
+    end
+
+    x = m.out_proj(x)
+
+    return x
+    # return with_weights ? (x, α) : x
 end
 
-using Flux, Functors, Test, NNlib, MLUtils
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention
+function dot_product_attention(q, k, v; dropout=nothing)
+    α = dot_product_attention_weights(q, k; dropout)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
+    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+
+    return x, α
+end
 
-mha = MultiHeadAttention(64, 8)
-sz = (64, 100, 32)
-x = rand(Float32, sz)
-y = mha(x)
-@test y isa Array{Float32, 3}
-@test size(y) == sz
+function dot_product_attention_weights(q, k; dropout=nothing)
+    @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    α = softmax(α, dims=1)
+    return dropout === nothing ? α : dropout(α)
+end
+
+
+struct QKVProj
+    k_proj::Dense
+    v_proj::Dense
+    q_proj::Dense
+end
+
+@functor QKVProj
+
+function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
+    q_in_dim, k_in_dim, v_in_dim = in_dim
+    return QKVProj(
+        Dense(k_in_dim => qkv_dim; bias),
+        Dense(v_in_dim => qkv_dim; bias),
+        Dense(q_in_dim => qkv_dim; bias)
+    )
+end
+
+function (proj::QKVProj)(q_in, k_in, v_in, num_heads)
+    q = proj.q_proj(q_in)
+    sz = size(q)
+    newsz = (sz[1] ÷ num_heads, num_heads, sz[2:end]...)
+    q = reshape(q, newsz)
+    k = reshape(proj.k_proj(k_in), newsz)
+    v = reshape(proj.v_proj(v_in), newsz)
+    return q, k, v
+end
+
+function (proj::QKVProj)(q_in, k_in, v_in)
+    return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
+end
+
+
+function perf(dim, len, batch_size, num_heads)
+    mha = MultiHeadAttention(dim, num_heads)
+    x = rand(Float32, (dim, len, batch_size))
+
+    y = mha(x, x, x)
+    @test y isa Array{Float32, 3}
+    @test size(y) == (dim, len, batch_size)
+
+
+    println("tullio")
+    @btime $mha($x, v=:tullio);
+    @btime gradient(m -> sum(m($x, v=:tullio)), $mha);
+
+    println("nnalib")
+    @btime $mha($x, $x, $x, v=:nnalib);
+    @btime gradient(m -> sum(m($x, v=:nnalib)), $mha);
+
+    if CUDA.functional()
+        mha_gpu = mha |> gpu
+        x_gpu = x |> gpu
+
+        println("tullio - gpu")
+        @btime $mha_gpu($x_gpu, v=:tullio);
+        @btime gradient(m -> sum(m($x_gpu, v=:tullio)), $mha_gpu);
+
+        println("nnalib - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, v=:nnalib);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, v=:nnalib)), $mha_gpu);
+    end
+    return nothing
+end
+
+perf(64, 100, 32, 8)
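
A usage sketch of the layer defined in this diff (illustrative, not part of the commit). Per the inline shape comments, inputs are feature-first, `(dims, seq_len, batch_size)`, and the constructor accepts either a single `Int` or the `Pair` forms handled by `mha_process_dims`:

```julia
# Assumes the definitions from src/layers/attention.jl above have been loaded.
using Flux

mha = MultiHeadAttention(64, 8)          # q/k/v/output dims all 64, 8 heads
x = rand(Float32, 64, 100, 32)           # (dims, seq_len, batch_size)
y = mha(x)                               # self-attention; same as mha(x, x, x)
size(y) == (64, 100, 32)                 # true
# y2 = mha(x; v = :nnalib)               # alternative NeuralAttentionlib backend

# Asymmetric dimensions via the Pair methods of mha_process_dims:
mha2 = MultiHeadAttention(64 => 128 => 64, 8)            # in 64, qkv 128, out 64
mha3 = MultiHeadAttention((64, 32, 32) => 128 => 64, 8)  # distinct q/k/v input dims
```

The `with_weights` keyword is accepted but, as the commented-out return at the end of the forward pass shows, the attention weights are not yet returned.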

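For intuition, the `@tullio` expression in `dot_product_attention_weights` computes the same raw scores as a `batched_mul` formulation in the style of the removed implementation; softmax over `dims=1` and dropout are then applied on top. A small consistency sketch, again not part of the commit:

```julia
using Tullio, LoopVectorization
using NNlib: batched_mul

q = rand(Float32, 8, 4, 10, 2)   # (head_dim, num_heads, q_len, batch)
k = rand(Float32, 8, 4, 12, 2)   # (head_dim, num_heads, kv_len, batch)

@tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]

# Fold (num_heads, batch) into one batch axis and take kᵀq slice by slice.
q2 = reshape(permutedims(q, (1, 3, 2, 4)), 8, 10, :)   # (head_dim, q_len, heads*batch)
k2 = reshape(permutedims(k, (1, 3, 2, 4)), 8, 12, :)   # (head_dim, kv_len, heads*batch)
α2 = batched_mul(permutedims(k2, (2, 1, 3)), q2)       # (kv_len, q_len, heads*batch)

α ≈ reshape(α2, 12, 10, 4, 2)   # true
```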