@@ -111,9 +111,11 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
     # [v] = [v_dim, kv_len, batch_size]
 
     if impl == :tullio
-        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
+        x, α = dot_product_attention_tullio(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     elseif impl == :nalib
         x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask)
+    elseif impl == :native
+        x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     else
         error("Unknown attention implementation")
     end
@@ -126,24 +128,30 @@ end
 reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
 flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
 
-function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws...)
     q, k, v = reshape_heads.((q, k, v), num_heads)
-    x, α = dot_product_attention(q, k, v; kws...)
+    x, α = dot_product_attention_tullio(q, k, v; kws...)
+    return flatten_heads(x), α
+end
+
+function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+    q, k, v = reshape_heads.((q, k, v), num_heads)
+    x, α = dot_product_attention_native(q, k, v; kws...)
     return flatten_heads(x), α
 end
 
 # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
-function dot_product_attention(q::A4, k::A4, v::A4;
+function dot_product_attention_tullio(q::A4, k::A4, v::A4;
         dropout=nothing, bias=nothing, mask=nothing)
 
-    α = dot_product_attention_weights(q, k; dropout, bias, mask)
+    α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask)
     # [α] = [kv_len, q_len, num_heads, batch_size]
     @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
     return x, α
 end
 
-function dot_product_attention_weights(q::A4{T}, k::A4{T};
+function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T};
         dropout=nothing, mask=nothing, bias=nothing) where T
 
     q = q ./ √T(size(q, 1))
@@ -162,6 +170,49 @@ function dot_product_attention_weights(q::A4{T}, k::A4{T};
     return dropout === nothing ? α : dropout(α)
 end
 
+function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    sz = size(x)[3:end]
+    @assert sz == size(y)[3:end]
+    x2 = reshape(x, size(x, 1), size(x, 2), :)
+    y2 = reshape(y, size(y, 1), size(y, 2), :)
+    z = NNlib.batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), sz...)
+end
+
+function dot_product_attention_native(q::A4, k::A4, v::A4;
+        dropout=nothing, bias=nothing, mask=nothing)
+
+    α = dot_product_attention_weights_native(q, k; dropout, bias, mask)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    vt = permutedims(v, (1, 3, 2, 4))
+    x = NNlib.batched_mul(vt, α)
+    x = permutedims(x, (1, 3, 2, 4))
+    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+    return x, α
+end
+
+function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
+        dropout=nothing, mask=nothing, bias=nothing) where T
+
+    q = q ./ √T(size(q, 1))
+    kt = permutedims(k, (3, 1, 2, 4))
+    qt = permutedims(q, (1, 3, 2, 4))
+    α = NNlib.batched_mul(kt, qt)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    if bias !== nothing
+        α = α .+ bias
+    end
+    if mask !== nothing
+        neginf = typemin(eltype(α))
+        α = ifelse.(mask, α, neginf)
+    end
+
+    α = softmax(α, dims=1)
+    return dropout === nothing ? α : dropout(α)
+end
+
 
 struct QKVProj
     q_proj::Dense
@@ -206,6 +257,10 @@ function perf(dim, len, batch_size, num_heads)
     println("nalib")
     @btime $mha($x, $x, $x, impl=:nalib);
     @btime gradient(m -> sum(m($x, impl=:nalib)), $mha);
+
+    println("native")
+    @btime $mha($x, $x, $x, impl=:native);
+    @btime gradient(m -> sum(m($x, impl=:native)), $mha);
 
     if CUDA.functional()
         mha_gpu = mha |> gpu
@@ -218,6 +273,10 @@ function perf(dim, len, batch_size, num_heads)
         println("nalib - gpu")
         @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
         @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);
+
+        println("native - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:native);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:native)), $mha_gpu);
     end
     return nothing
 end
@@ -240,6 +299,12 @@ function test(dim, num_heads, len, batch_size)
     @test size(α) == size(α2)
     @test α2 ≈ α
 
+    y2b, α2b = mha(q, k, v, impl=:native, with_weights=true)
+    @test size(y) == size(y2b)
+    @test y2b ≈ y
+    @test size(α) == size(α2b)
+    @test α2b ≈ α
+
     mask = make_causal_mask(q)
     y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask)
     y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
@@ -273,16 +338,22 @@ perf(128, 8, 128, 32)
 # nalib - 6 threads
 # 7.832 ms (187 allocations: 7.76 MiB)
 # 29.823 ms (988 allocations: 16.19 MiB)
+# native
+# 6.269 ms (90 allocations: 9.25 MiB)
+# 15.492 ms (1250 allocations: 22.19 MiB)
 # tullio - gpu
 # 147.746 μs (523 allocations: 24.59 KiB)
 # 957.111 μs (2413 allocations: 127.88 KiB)
 # nalib - gpu
 # 165.109 μs (411 allocations: 18.05 KiB)
 # 659.685 μs (1527 allocations: 86.09 KiB)
-
-dim = 2; len = 3; batch_size = 1; num_heads = 1
-mha = MultiHeadAttention(dim, num_heads)
-x = rand(Float32, (dim, len, batch_size))
-mask = make_causal_mask(x)
-y, α = mha(x; impl=:tullio, with_weights=true, mask)
-y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
+# native - gpu
+# 158.396 μs (443 allocations: 20.06 KiB)
+# 920.633 μs (2308 allocations: 118.78 KiB)
+
+# dim = 2; len = 3; batch_size = 1; num_heads = 1
+# mha = MultiHeadAttention(dim, num_heads)
+# x = rand(Float32, (dim, len, batch_size))
+# mask = make_causal_mask(x)
+# y, α = mha(x; impl=:tullio, with_weights=true, mask)
+# y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
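
For reference, a minimal usage sketch of the new :native path, mirroring the checks in test above. It assumes the MultiHeadAttention constructor and keyword interface shown in these hunks and that the file's using preamble (outside this diff) is loaded; the sizes are illustrative, not taken from the patch.

    dim, num_heads, len, batch_size = 64, 4, 10, 32        # illustrative sizes
    mha = MultiHeadAttention(dim, num_heads)
    x = rand(Float32, (dim, len, batch_size))
    y1, α1 = mha(x, x, x; impl=:tullio, with_weights=true)  # Tullio einsum implementation
    y2, α2 = mha(x, x, x; impl=:native, with_weights=true)  # permutedims + batched_mul implementation
    y1 ≈ y2 && α1 ≈ α2                                      # both paths should agree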