
Commit b9fa236

[ci skip] add native implementation
1 parent: 4d0ada2, commit: b9fa236

2 files changed: +85 -13 lines

src/layers/attention.jl

Lines changed: 84 additions & 13 deletions
@@ -111,9 +111,11 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
     # [v] = [v_dim, kv_len, batch_size]
 
     if impl == :tullio
-        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
+        x, α = dot_product_attention_tullio(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     elseif impl == :nalib
         x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask)
+    elseif impl == :native
+        x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     else
         error("Unknown attention implementation")
     end
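For orientation, a minimal usage sketch of the new dispatch branch (not part of the patch): it assumes the prototype constructor MultiHeadAttention(dim, num_heads) and the with_weights keyword that the test() function further down exercises, with illustrative shapes.

dim, num_heads, len, batch_size = 8, 2, 5, 3
mha = MultiHeadAttention(dim, num_heads)               # prototype layer from this file
x = rand(Float32, dim, len, batch_size)                # [dim, len, batch_size]
y, α = mha(x, x, x; impl=:native, with_weights=true)   # routes to dot_product_attention_native
# expected: size(y) == (dim, len, batch_size), size(α) == (len, len, num_heads, batch_size)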
@@ -126,24 +128,30 @@ end
 reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
 flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
 
-function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws...)
     q, k, v = reshape_heads.((q, k, v), num_heads)
-    x, α = dot_product_attention(q, k, v; kws...)
+    x, α = dot_product_attention_tullio(q, k, v; kws...)
+    return flatten_heads(x), α
+end
+
+function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+    q, k, v = reshape_heads.((q, k, v), num_heads)
+    x, α = dot_product_attention_native(q, k, v; kws...)
     return flatten_heads(x), α
 end
 
 # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
-function dot_product_attention(q::A4, k::A4, v::A4;
+function dot_product_attention_tullio(q::A4, k::A4, v::A4;
             dropout=nothing, bias=nothing, mask=nothing)
 
-    α = dot_product_attention_weights(q, k; dropout, bias, mask)
+    α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask)
     # [α] = [kv_len, q_len, num_heads, batch_size]
     @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
     return x, α
 end
 
-function dot_product_attention_weights(q::A4{T}, k::A4{T};
+function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T};
             dropout=nothing, mask=nothing, bias=nothing) where T
 
     q = q ./ T(size(q, 1))
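As an aside, a self-contained sketch of what the reshape_heads / flatten_heads helpers do to the feature dimension; the definitions are copied from the hunk above and the shapes are illustrative.

reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
flatten_heads(x) = reshape(x, :, size(x)[3:end]...)

q = rand(Float32, 8, 5, 2)        # [dim, q_len, batch_size]
qh = reshape_heads(q, 4)          # split dim across 4 heads
@assert size(qh) == (2, 4, 5, 2)  # [dim ÷ num_heads, num_heads, q_len, batch_size]
@assert flatten_heads(qh) == q    # flatten_heads undoes reshape_heads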
@@ -162,6 +170,49 @@ function dot_product_attention_weights(q::A4{T}, k::A4{T};
     return dropout === nothing ? α : dropout(α)
 end
 
+function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    sz = size(x)[3:end]
+    @assert sz == size(y)[3:end]
+    x2 = reshape(x, size(x, 1), size(x, 2), :)
+    y2 = reshape(y, size(y, 1), size(y, 2), :)
+    z = NNlib.batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), sz...)
+end
+
+function dot_product_attention_native(q::A4, k::A4, v::A4;
+            dropout=nothing, bias=nothing, mask=nothing)
+
+    α = dot_product_attention_weights_native(q, k; dropout, bias, mask)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    vt = permutedims(v, (1, 3, 2, 4))
+    x = NNlib.batched_mul(vt, α)
+    x = permutedims(x, (1, 3, 2, 4))
+    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+    return x, α
+end
+
+function dot_product_attention_weights_native(q::A4{T}, k::A4{T};
+            dropout=nothing, mask=nothing, bias=nothing) where T
+
+    q = q ./ T(size(q, 1))
+    kt = permutedims(k, (3, 1, 2, 4))
+    qt = permutedims(q, (1, 3, 2, 4))
+    α = NNlib.batched_mul(kt, qt)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    if bias !== nothing
+        α = α .+ bias
+    end
+    if mask !== nothing
+        neginf = typemin(eltype(α))
+        α = ifelse.(mask, α, neginf)
+    end
+
+    α = softmax(α, dims=1)
+    return dropout === nothing ? α : dropout(α)
+end
+
 
 struct QKVProj
     q_proj::Dense
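A short sanity check for the 4-dim batched_mul method added above, sketched under the assumption that this method is loaded (stock NNlib.batched_mul only accepts 3-dim arrays): the trailing dimensions are folded into a single batch dimension, every slice is multiplied as an ordinary matrix product, and the trailing shape is restored.

using NNlib  # plus the AbstractArray{T,N} batched_mul method from this hunk

x = rand(Float32, 3, 4, 2, 5)
y = rand(Float32, 4, 6, 2, 5)
z = NNlib.batched_mul(x, y)
@assert size(z) == (3, 6, 2, 5)                        # (3, 4) * (4, 6) per trailing slice
@assert z[:, :, 2, 3] ≈ x[:, :, 2, 3] * y[:, :, 2, 3]  # each slice is a plain matmul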
@@ -206,6 +257,10 @@ function perf(dim, len, batch_size, num_heads)
     println("nalib")
     @btime $mha($x, $x, $x, impl=:nalib);
     @btime gradient(m -> sum(m($x, impl=:nalib)), $mha);
+
+    println("native")
+    @btime $mha($x, $x, $x, impl=:native);
+    @btime gradient(m -> sum(m($x, impl=:native)), $mha);
 
     if CUDA.functional()
         mha_gpu = mha |> gpu
@@ -218,6 +273,10 @@ function perf(dim, len, batch_size, num_heads)
         println("nalib - gpu")
         @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
         @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);
+
+        println("native - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:native);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:native)), $mha_gpu);
     end
     return nothing
 end
@@ -240,6 +299,12 @@ function test(dim, num_heads, len, batch_size)
     @test size(α) == size(α2)
     @test α2 ≈ α
 
+    y2b, α2b = mha(q, k, v, impl=:native, with_weights=true)
+    @test size(y) == size(y2b)
+    @test y2b ≈ y
+    @test size(α) == size(α2b)
+    @test α2b ≈ α
+
     mask = make_causal_mask(q)
     y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask)
     y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
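Complementing the MHA-level checks in this hunk, a per-head consistency sketch; it assumes the 4-dim dot_product_attention_tullio and dot_product_attention_native kernels from this file are in scope and should agree up to floating-point error, since both scale q, softmax over the kv dimension, and contract the weights with v.

q = rand(Float32, 4, 2, 6, 3)   # [dim ÷ num_heads, num_heads, q_len, batch_size]
k = rand(Float32, 4, 2, 5, 3)   # [dim ÷ num_heads, num_heads, kv_len, batch_size]
v = rand(Float32, 4, 2, 5, 3)

x1, α1 = dot_product_attention_tullio(q, k, v)
x2, α2 = dot_product_attention_native(q, k, v)
@assert x1 ≈ x2                 # same [dim ÷ num_heads, num_heads, q_len, batch_size] output
@assert α1 ≈ α2                 # same [kv_len, q_len, num_heads, batch_size] weights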
@@ -273,16 +338,22 @@ perf(128, 8, 128, 32)
 # nalib - 6 threads
 # 7.832 ms (187 allocations: 7.76 MiB)
 # 29.823 ms (988 allocations: 16.19 MiB)
+# native
+# 6.269 ms (90 allocations: 9.25 MiB)
+# 15.492 ms (1250 allocations: 22.19 MiB)
 # tullio - gpu
 # 147.746 μs (523 allocations: 24.59 KiB)
 # 957.111 μs (2413 allocations: 127.88 KiB)
 # nalib - gpu
 # 165.109 μs (411 allocations: 18.05 KiB)
 # 659.685 μs (1527 allocations: 86.09 KiB)
-
-dim = 2; len = 3; batch_size = 1; num_heads = 1
-mha = MultiHeadAttention(dim, num_heads)
-x = rand(Float32, (dim, len, batch_size))
-mask = make_causal_mask(x)
-y, α = mha(x; impl=:tullio, with_weights=true, mask)
-y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
+# native - gpu
+# 158.396 μs (443 allocations: 20.06 KiB)
+# 920.633 μs (2308 allocations: 118.78 KiB)
+
+# dim = 2; len = 3; batch_size = 1; num_heads = 1
+# mha = MultiHeadAttention(dim, num_heads)
+# x = rand(Float32, (dim, len, batch_size))
+# mask = make_causal_mask(x)
+# y, α = mha(x; impl=:tullio, with_weights=true, mask)
+# y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())

test_jax.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 #import tensorflow_datasets as tfds # TFDS for MNIST
 # %%
 x = jnp.arange(16).reshape(1,2,2,4) / 16
+alpha = nn.dot_product_attention_weights(x, x)
 y = nn.dot_product_attention(x, x, x)
 yt = y.transpose((3,2,1,0))
