
Commit 0ec9a00

move multiheadattention from Metalhead
1 parent 4da339e commit 0ec9a00

File tree

1 file changed (+63, -0 lines)


src/layers/attention.jl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
using Flux, Functors, Test, NNlib, MLUtils  # Flux: Dense/Dropout/Chain; Functors: @functor; NNlib: softmax/batched_mul; MLUtils: chunk

"""
    MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                       attn_dropout_prob = 0., proj_dropout_prob = 0.)

Multi-head self-attention layer.

# Arguments

- `planes`: number of input channels
- `nheads`: number of attention heads
- `qkv_bias`: whether to use a bias in the layer that computes the query, key and value
- `attn_dropout_prob`: dropout probability after the self-attention layer
- `proj_dropout_prob`: dropout probability after the projection layer
"""
struct MultiHeadAttention{P, Q, R}
    nheads::Int
    qkv_layer::P
    attn_drop::Q
    projection::R
end

@functor MultiHeadAttention

function MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                            attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
    @assert planes % nheads == 0 "planes should be divisible by nheads"
    qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
    attn_drop = Dropout(attn_dropout_prob)
    proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
    return MultiHeadAttention(nheads, qkv_layer, attn_drop, proj)
end
function (m::MultiHeadAttention)(x::AbstractArray{T, 3}) where {T}
    nfeatures, seq_len, batch_size = size(x)
    # flatten the sequence and batch dimensions so the shared Dense layer sees a matrix
    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
    # project to 3 * planes features holding the stacked query, key and value
    qkv = m.qkv_layer(x_reshaped)
    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
    scale = convert(T, sqrt(size(query, 1) / m.nheads))
    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
                           seq_len * batch_size)
    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                             m.nheads, seq_len * batch_size)

    # scaled dot-product attention weights, followed by attention dropout
    attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale)
    attention = m.attn_drop(attention)

    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                             m.nheads, seq_len * batch_size)
    pre_projection = reshape(batched_mul(attention, value_reshaped),
                             (nfeatures, seq_len, batch_size))
    # merge the heads and apply the output projection with its dropout
    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
    return reshape(y, :, seq_len, batch_size)
end

# quick smoke test for the layer above
mha = MultiHeadAttention(64, 8)
sz = (64, 100, 32)
x = rand(Float32, sz)
y = mha(x)
@test y isa Array{Float32, 3}
@test size(y) == sz
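
For context, a minimal usage sketch (not part of this commit): the layer can be dropped into a pre-norm transformer-style encoder block, roughly as it is used for ViT in Metalhead. The names `planes` and `block` and the LayerNorm/SkipConnection/Dense arrangement below are illustrative assumptions, not definitions from this file.

# Illustrative only: a pre-norm encoder block built around the layer above,
# assuming the file has been included and Flux is loaded.
planes = 64
block = Chain(SkipConnection(Chain(LayerNorm(planes), MultiHeadAttention(planes, 8)), +),
              SkipConnection(Chain(LayerNorm(planes), Dense(planes, planes, gelu)), +))

x = rand(Float32, planes, 16, 4)          # (features, sequence length, batch)
@test size(block(x)) == (planes, 16, 4)   # the block preserves the input shape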

0 commit comments
