Commit 31bd1c8

support mask = :causal
1 parent: b9fa236

File tree

1 file changed (+36, -26 lines)

src/layers/attention.jl: 36 additions & 26 deletions
@@ -44,26 +44,25 @@ Multi-head dot-product attention layer.
 mha = MultiHeadAttention(64, 8)
 ```
 """
-struct MultiHeadAttention
+struct MultiHeadAttention{P1, D, P2}
     num_heads::Int
-    qkv_proj
-    attn_drop
-    out_proj
+    qkv_proj::P1
+    attn_drop::D
+    out_proj::P2
 end

 @functor MultiHeadAttention

 function MultiHeadAttention(dims, num_heads::Int;
                             bias::Bool = false,
                             init = glorot_uniform,
-                            attn_dropout_prob = 0.0,
-                            out_proj_dropout_prob = 0.0)
+                            attn_dropout_prob = 0.0)

     dims = mha_process_dims(dims)
     @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads"
     qkv_proj = QKVProj(dims; bias, init)
     attn_drop = Dropout(attn_dropout_prob)
-    out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob))
+    out_proj = Dense(dims.v => dims.out; bias, init)
     return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
 end
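
The struct gains type parameters so each field is stored with a concrete type, letting Julia specialize the forward pass instead of dispatching through untyped fields; the output projection also loses its `Dropout` wrapper together with the removed `out_proj_dropout_prob` keyword. A self-contained sketch of the same parameterization idiom (generic names, not taken from this file):

```julia
# Wrapping a callable in a parameterized struct keeps the field type concrete,
# so `w.f(x)` is statically dispatched rather than going through an `Any` field.
struct Wrapped{F}
    f::F
end
(w::Wrapped)(x) = w.f(x)

w = Wrapped(sin)
w(1.0)                     # 0.8414709848078965
isconcretetype(typeof(w))  # true: typeof(w) === Wrapped{typeof(sin)}
```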

@@ -100,7 +99,7 @@ end
 (m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)

 function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
-            with_weights=false, mask=nothing, impl=:tullio)
+            with_weights=false, mask=nothing, impl=:native)
     ## [q_in] = [q_in_dim, q_len, batch_size]
     ## [k_in] = [k_in_dim, kv_len, batch_size]
     ## [v_in] = [v_in_dim, kv_len, batch_size]
@@ -115,7 +114,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
     elseif impl == :nalib
         x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask)
     elseif impl == :native
-        x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
+        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     else
         error("Unknown attention implementation")
     end
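
With the default `impl` moving from `:tullio` to `:native`, the branches above should stay interchangeable on the same inputs. A hypothetical REPL check, assuming the layer and the relevant code paths from this file are loaded; with the default `attn_dropout_prob = 0.0` the dropout is a no-op, so the backends should agree up to floating-point error:

```julia
mha = MultiHeadAttention(64, 8)
x = rand(Float32, 64, 16, 2)   # (dim, len, batch_size)

# self-attention: q = k = v = x
y_native, α_native = mha(x, x, x; with_weights=true, impl=:native)
y_tullio, α_tullio = mha(x, x, x; with_weights=true, impl=:tullio)

y_native ≈ y_tullio   # expected: true
α_native ≈ α_tullio   # expected: true
```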
@@ -134,11 +133,6 @@ function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws..
     return flatten_heads(x), α
 end

-function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...)
-    q, k, v = reshape_heads.((q, k, v), num_heads)
-    x, α = dot_product_attention_native(q, k, v; kws...)
-    return flatten_heads(x), α
-end

 # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
 function dot_product_attention_tullio(q::A4, k::A4, v::A4;
@@ -179,23 +173,34 @@ function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where
     return reshape(z, size(z, 1), size(z, 2), sz...)
 end

-function dot_product_attention_native(q::A4, k::A4, v::A4;
+function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+    q, k, v = reshape_heads.((q, k, v), num_heads)
+    x, α = dot_product_attention(q, k, v; kws...)
+    return flatten_heads(x), α
+end
+
+function dot_product_attention(q::A4, k::A4, v::A4;
         dropout=nothing, bias=nothing, mask=nothing)

-    α = dot_product_attention_weights_native(q, k; dropout, bias, mask)
+    α = dot_product_attention_weights(q, k; dropout, bias, mask)
     # [α] = [kv_len, q_len, num_heads, batch_size]

+    # The following permutations and batched_mul are equivalent to
+    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     vt = permutedims(v, (1, 3, 2, 4))
     x = NNlib.batched_mul(vt, α)
     x = permutedims(x, (1, 3, 2, 4))
     # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
     return x, α
 end

-function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
+function dot_product_attention_weights(q::A4{T}, k::A4{T};
         dropout=nothing, mask=nothing, bias=nothing) where T

     q = q ./ T(size(q, 1))
+
+    # The following permutations and batched_mul are equivalent to
+    # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     kt = permutedims(k, (3, 1, 2, 4))
     qt = permutedims(q, (1, 3, 2, 4))
     α = NNlib.batched_mul(kt, qt)
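
The new comments record which einsum each `permutedims` + `batched_mul` pair replaces. The standalone check below reproduces the weights contraction `α[j, i, h, b] = Σ_d q[d, h, i, b] * k[d, h, j, b]`, merging the trailing (heads, batch) dimensions by hand instead of relying on the 4-D `NNlib.batched_mul` extension defined earlier in the file (array sizes here are arbitrary):

```julia
using NNlib: batched_mul

d, h, qlen, klen, b = 4, 2, 3, 5, 2
q = rand(Float32, d, h, qlen, b)
k = rand(Float32, d, h, klen, b)

# permute, then fold (heads, batch) into one batch dimension for batched_mul
kt = reshape(permutedims(k, (3, 1, 2, 4)), klen, d, h * b)
qt = reshape(permutedims(q, (1, 3, 2, 4)), d, qlen, h * b)
α  = reshape(batched_mul(kt, qt), klen, qlen, h, b)

# reference: explicit einsum α[j, i, hh, bb] = Σ_dd q[dd, hh, i, bb] * k[dd, hh, j, bb]
αref = [sum(q[dd, hh, i, bb] * k[dd, hh, j, bb] for dd in 1:d)
        for j in 1:klen, i in 1:qlen, hh in 1:h, bb in 1:b]

α ≈ αref   # true
```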
@@ -204,7 +209,11 @@ function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
     if bias !== nothing
         α = α .+ bias
     end
+
     if mask !== nothing
+        if mask === :causal
+            mask = make_causal_mask(α)
+        end
         neginf = typemin(eltype(α))
         α = ifelse.(mask, α, neginf)
     end
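
This hunk implements the commit message: passing `mask = :causal` now builds the boolean mask on the fly via `make_causal_mask(α)` instead of requiring the caller to materialize it. As far as this diff shows, the symbol is only expanded here, i.e. on the `:native` path; the `:tullio` and `:nalib` branches still receive the raw value. A hedged usage sketch, assuming the self-attention call seen in the scratch comments at the bottom of the file and `make_causal_mask` defined elsewhere:

```julia
dim, len, batch_size, num_heads = 64, 10, 4, 8

mha = MultiHeadAttention(dim, num_heads)
x = rand(Float32, dim, len, batch_size)

# before this commit the mask had to be built up front:
#   mask = make_causal_mask(x)
#   y, α = mha(x; with_weights=true, mask)
# now the symbol is enough on the :native path:
y, α = mha(x; impl=:native, with_weights=true, mask=:causal)

size(α)   # expected (len, len, num_heads, batch_size)
```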
@@ -329,15 +338,9 @@ perf(128, 8, 128, 32)
 # tullio
 # 5.475 ms (80 allocations: 7.25 MiB)
 # 13.073 ms (1172 allocations: 18.18 MiB)
-# tullio - 6 threads
-# 4.818 ms (192 allocations: 7.26 MiB)
-# 10.927 ms (1398 allocations: 18.19 MiB)
 # nalib
 # 6.040 ms (91 allocations: 7.75 MiB)
 # 14.542 ms (696 allocations: 16.17 MiB)
-# nalib - 6 threads
-# 7.832 ms (187 allocations: 7.76 MiB)
-# 29.823 ms (988 allocations: 16.19 MiB)
 # native
 # 6.269 ms (90 allocations: 9.25 MiB)
 # 15.492 ms (1250 allocations: 22.19 MiB)
@@ -351,9 +354,16 @@ perf(128, 8, 128, 32)
 # 158.396 μs (443 allocations: 20.06 KiB)
 # 920.633 μs (2308 allocations: 118.78 KiB)

-# dim = 2; len = 3; batch_size = 1; num_heads = 1
+# perf(384, 12, 256, 32)
+
+
+# dim, len, batch_size, num_heads = 128, 8, 128, 32;
+# # dim = 384; len = 128; batch_size = 32; num_heads = 12
 # mha = MultiHeadAttention(dim, num_heads)
 # x = rand(Float32, (dim, len, batch_size))
-# mask = make_causal_mask(x)
-# y, α = mha(x; impl=:tullio, with_weights=true, mask)
+# @btime mha(x, impl=:tullio);
+# @btime mha(x, impl=:native);
+# @profview mha(x, impl=:tullio);
+# @profview [mha(x, impl=:native) for _ in 1:100];
+# y, α = mha(x; impl=:native, with_weights=true, mask)
 # y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
