@@ -44,26 +44,25 @@ Multi-head dot-product attention layer.
 mha = MultiHeadAttention(64, 8)
 ```
 """
-struct MultiHeadAttention
+struct MultiHeadAttention{P1, D, P2}
     num_heads::Int
-    qkv_proj
-    attn_drop
-    out_proj
+    qkv_proj::P1
+    attn_drop::D
+    out_proj::P2
 end
 
 @functor MultiHeadAttention
 
 function MultiHeadAttention(dims, num_heads::Int;
                             bias::Bool = false,
                             init = glorot_uniform,
-                            attn_dropout_prob = 0.0,
-                            out_proj_dropout_prob = 0.0)
+                            attn_dropout_prob = 0.0)
 
     dims = mha_process_dims(dims)
     @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads"
     qkv_proj = QKVProj(dims; bias, init)
     attn_drop = Dropout(attn_dropout_prob)
-    out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob))
+    out_proj = Dense(dims.v => dims.out; bias, init)
     return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
 end
 
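Not part of the commit, but useful context for the constructor change above: a minimal usage sketch of the refactored layer. It assumes the definitions in this file (`MultiHeadAttention`, `QKVProj`, `mha_process_dims`) are loaded and that passing a single integer for `dims` gives equal query/key/value/output sizes.

```julia
# Minimal usage sketch (assumptions: the definitions in this file are loaded,
# and `dims = 64` yields 64-dimensional projections throughout).
dim, num_heads, len, batch_size = 64, 8, 10, 32

mha = MultiHeadAttention(dim, num_heads; attn_dropout_prob = 0.1)

q  = rand(Float32, dim, len, batch_size)      # [q_in_dim, q_len, batch_size]
kv = rand(Float32, dim, 2len, batch_size)     # [kv_in_dim, kv_len, batch_size]

# Cross-attention: the two-argument method reuses `kv` for both keys and values.
y, α = mha(q, kv; with_weights = true)
# y: [out_dim, q_len, batch_size], α: [kv_len, q_len, num_heads, batch_size]
```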
@@ -100,7 +99,7 @@
 (m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)
 
 function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
-                                 with_weights=false, mask=nothing, impl=:tullio)
+                                 with_weights=false, mask=nothing, impl=:native)
     ## [q_in] = [q_in_dim, q_len, batch_size]
     ## [k_in] = [k_in_dim, kv_len, batch_size]
     ## [v_in] = [v_in_dim, kv_len, batch_size]
@@ -115,7 +114,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
     elseif impl == :nalib
         x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask)
     elseif impl == :native
-        x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
+        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     else
         error("Unknown attention implementation")
     end
@@ -134,11 +133,6 @@ function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws..
     return flatten_heads(x), α
 end
 
-function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...)
-    q, k, v = reshape_heads.((q, k, v), num_heads)
-    x, α = dot_product_attention_native(q, k, v; kws...)
-    return flatten_heads(x), α
-end
 
 # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
 function dot_product_attention_tullio(q::A4, k::A4, v::A4;
@@ -179,23 +173,34 @@ function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where
     return reshape(z, size(z, 1), size(z, 2), sz...)
 end
 
-function dot_product_attention_native(q::A4, k::A4, v::A4;
+function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+    q, k, v = reshape_heads.((q, k, v), num_heads)
+    x, α = dot_product_attention(q, k, v; kws...)
+    return flatten_heads(x), α
+end
+
+function dot_product_attention(q::A4, k::A4, v::A4;
         dropout=nothing, bias=nothing, mask=nothing)
 
-    α = dot_product_attention_weights_native(q, k; dropout, bias, mask)
+    α = dot_product_attention_weights(q, k; dropout, bias, mask)
     # [α] = [kv_len, q_len, num_heads, batch_size]
 
+    # The following permutations and batched_mul are equivalent to
+    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     vt = permutedims(v, (1, 3, 2, 4))
     x = NNlib.batched_mul(vt, α)
     x = permutedims(x, (1, 3, 2, 4))
     # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
     return x, α
 end
 
-function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
+function dot_product_attention_weights(q::A4{T}, k::A4{T};
         dropout=nothing, mask=nothing, bias=nothing) where T
 
     q = q ./ √T(size(q, 1))
+
+    # The following permutations and batched_mul are equivalent to
+    # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     kt = permutedims(k, (3, 1, 2, 4))
     qt = permutedims(q, (1, 3, 2, 4))
     α = NNlib.batched_mul(kt, qt)
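Not in the diff: a quick numerical check of the contraction equivalence stated in the new comments, written with explicit loops instead of `@tullio`. It assumes the N-dimensional `NNlib.batched_mul` overload defined earlier in this file is loaded.

```julia
# Sketch only: verify that permutedims + batched_mul computes
# α[j, i, h, b] = Σ_d q[d, h, i, b] * k[d, h, j, b], as claimed above.
using NNlib  # the 4D batched_mul call relies on the overload defined in this file

d, h, qlen, klen, b = 2, 3, 4, 5, 1
q = rand(Float32, d, h, qlen, b)
k = rand(Float32, d, h, klen, b)

kt = permutedims(k, (3, 1, 2, 4))   # [klen, d, h, b]
qt = permutedims(q, (1, 3, 2, 4))   # [d, qlen, h, b]
α = NNlib.batched_mul(kt, qt)       # [klen, qlen, h, b]

α_loops = [sum(q[dd, hh, i, bb] * k[dd, hh, j, bb] for dd in 1:d)
           for j in 1:klen, i in 1:qlen, hh in 1:h, bb in 1:b]
@assert α ≈ α_loops
```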
@@ -204,7 +209,11 @@ function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
     if bias !== nothing
         α = α .+ bias
     end
+
     if mask !== nothing
+        if mask === :causal
+            mask = make_causal_mask(α)
+        end
         neginf = typemin(eltype(α))
         α = ifelse.(mask, α, neginf)
     end
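Not in the diff: a toy, two-dimensional illustration of the `ifelse`-based masking used above. In the real code `α` is `[kv_len, q_len, num_heads, batch_size]` and `make_causal_mask` is assumed to return a broadcastable boolean array.

```julia
# Toy illustration of the masking above, not library code.
# Rows index key positions j, columns index query positions i.
α = rand(Float32, 4, 4)
mask = [j <= i for j in 1:4, i in 1:4]   # causal: query i only sees keys j <= i
neginf = typemin(eltype(α))              # ≈ -3.4f38, effectively -Inf
α = ifelse.(mask, α, neginf)             # masked scores vanish after softmax
```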
@@ -329,15 +338,9 @@ perf(128, 8, 128, 32)
 # tullio
 # 5.475 ms (80 allocations: 7.25 MiB)
 # 13.073 ms (1172 allocations: 18.18 MiB)
-# tullio - 6 threads
-# 4.818 ms (192 allocations: 7.26 MiB)
-# 10.927 ms (1398 allocations: 18.19 MiB)
 # nalib
 # 6.040 ms (91 allocations: 7.75 MiB)
 # 14.542 ms (696 allocations: 16.17 MiB)
-# nalib - 6 threads
-# 7.832 ms (187 allocations: 7.76 MiB)
-# 29.823 ms (988 allocations: 16.19 MiB)
 # native
 # 6.269 ms (90 allocations: 9.25 MiB)
 # 15.492 ms (1250 allocations: 22.19 MiB)
@@ -351,9 +354,16 @@ perf(128, 8, 128, 32)
 # 158.396 μs (443 allocations: 20.06 KiB)
 # 920.633 μs (2308 allocations: 118.78 KiB)
 
-# dim = 2; len = 3; batch_size = 1; num_heads = 1
+# perf(384, 12, 256, 32)
+
+
+# dim, len, batch_size, num_heads = 128, 8, 128, 32;
+# # dim = 384; len = 128; batch_size = 32; num_heads = 12
 # mha = MultiHeadAttention(dim, num_heads)
 # x = rand(Float32, (dim, len, batch_size))
-# mask = make_causal_mask(x)
-# y, α = mha(x; impl=:tullio, with_weights=true, mask)
+# @btime mha(x, impl=:tullio);
+# @btime mha(x, impl=:native);
+# @profview mha(x, impl=:tullio);
+# @profview [mha(x, impl=:native) for _ in 1:100];
+# y, α = mha(x; impl=:native, with_weights=true, mask)
 # y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
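Not part of the commit: a small follow-up to the scratch notes above, checking that the implementations agree when dropout is off. It assumes the single-array self-attention method `mha(x; ...)` used in the notes and that `make_causal_mask(x)` produces a boolean mask accepted by both code paths.

```julia
# Sketch only: with attn_dropout_prob left at 0.0 the implementations should
# agree up to floating-point error.
mask = make_causal_mask(x)
y_native, α_native = mha(x; impl = :native, with_weights = true, mask)
y_tullio, α_tullio = mha(x; impl = :tullio, with_weights = true, mask)
@assert y_native ≈ y_tullio
@assert α_native ≈ α_tullio
```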