Commit 742f2b5

[ci skip] updates
1 parent 1b313f2 commit 742f2b5

File tree: 1 file changed (+76, -50 lines)

src/layers/attention.jl

Lines changed: 76 additions & 50 deletions
@@ -1,10 +1,12 @@
 using Flux, Functors, Test, LinearAlgebra, Random, Statistics
-using CUDA, CUDAKernels, KernelAbstractions, LoopVectorization
-using Tullio
+using CUDA
+using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
 using NeuralAttentionlib
 using BenchmarkTools
 CUDA.allowscalar(false)
+
 const A3{T} = AbstractArray{T, 3}
+const A4{T} = AbstractArray{T, 4}

 """
     MultiHeadAttention(dims, num_heads;
@@ -48,7 +50,8 @@ function MultiHeadAttention(dims, num_heads::Int;
                     bias::Bool = false,
                     # init = glorot_uniform, # TODO
                     attn_dropout_prob = 0.0,
-                    out_proj_dropout_prob = 0.0)
+                    out_proj_dropout_prob = 0.0,
+                    self=false)

     dims = mha_process_dims(dims)
     @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
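
Not part of the commit: a minimal construction sketch under the keyword signature shown above, with arbitrary dimensions and dropout probabilities (the `64 => 128 => 32` form is interpreted by `mha_process_dims` in the next hunk as 64-dim inputs, 128-dim q/k/v projections, and a 32-dim output).

    # Hypothetical construction of the layer defined in this file.
    mha = MultiHeadAttention(64 => 128 => 32, 8;
                             bias = false,
                             attn_dropout_prob = 0.1,
                             out_proj_dropout_prob = 0.1)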
@@ -58,48 +61,59 @@ function MultiHeadAttention(dims, num_heads::Int;
     return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
 end

-mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
-mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out)
-mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
+mha_process_dims(dims::Int) =
+    (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+
+mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
+    (; q_in = in, k_in = in, v_in = in, qkv, out)
+
+mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
+    (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)

 # self-attention
 (m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)

-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, v=:tullio)
+function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio)
     ## [q_in] = [q_in_dim, q_len, batch_size]
     ## [k_in] = [k_in_dim, kv_len, batch_size]
     ## [v_in] = [v_in_dim, kv_len, batch_size]

-    if v == :tullio
-        q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads)
-        # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size]
-        # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size]
-
-        x, α = dot_product_attention(q, k, v; dropout=m.attn_drop)
-        x = reshape(x, :, size(x, 3), size(x, 4))
-    elseif v == :nnalib
-        q, k, v = m.qkv_proj(q_in, k_in, v_in)
-        x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v)
+    q, k, v = m.qkv_proj(q_in, k_in, v_in)
+    # [q] = [qkv_dim, q_len, batch_size]
+    # [k] = [v] = [qkv_dim, kv_len, batch_size]
+    if impl == :tullio
+        x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop)
+    elseif impl == :nnalib
+        x, α = NeuralAttentionlib.multihead_qkv_attention(
+                    NeuralAttentionlib.score_returning,
+                    m.num_heads, q, k, v)
     else
         error("Unknown attention implementation")
     end

     x = m.out_proj(x)

-    return x
-    # return with_weights ? (x, α) : x
+    return with_weights ? (x, α) : x
 end

-# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention
-function dot_product_attention(q, k, v; dropout=nothing)
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
+function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing)
     α = dot_product_attention_weights(q, k; dropout)
     # [α] = [kv_len, q_len, num_heads, batch_size]
     @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
-
     return x, α
 end

+function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+    q, k, v = reshape_heads.((q, k, v), num_heads)
+    x, α = dot_product_attention(q, k, v; kws...)
+    return flatten_heads(x), α
+end
+
+reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
+
 function dot_product_attention_weights(q, k; dropout=nothing)
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]
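
Two illustrative sketches for this hunk, neither part of the commit. The Tullio expression in `dot_product_attention_weights` computes the raw scores α[j, i, h, b] = Σ_d q[d, h, i, b] * k[d, h, j, b], i.e. Kᵀ·Q per head and batch; `reshape_heads` splits the projection dimension into (head_dim, num_heads) so that expression can run per head, and `flatten_heads` undoes the split. The snippet below shows the named tuples produced by the three `mha_process_dims` methods and the head-reshaping round trip, assuming the definitions above are in scope.

    # Hypothetical REPL session against the definitions above.
    mha_process_dims(64)                          # (q_in = 64, k_in = 64, v_in = 64, qkv = 64, out = 64)
    mha_process_dims(64 => 128 => 32)             # (q_in = 64, k_in = 64, v_in = 64, qkv = 128, out = 32)
    mha_process_dims((64, 48, 48) => 128 => 32)   # (q_in = 64, k_in = 48, v_in = 48, qkv = 128, out = 32)

    # Head reshaping: (qkv, len, batch) -> (qkv ÷ num_heads, num_heads, len, batch) and back.
    x  = rand(Float32, 8, 10, 2)
    xh = reshape_heads(x, 2)        # size (4, 2, 10, 2)
    size(flatten_heads(xh))         # (8, 10, 2) -- a plain reshape, so the round trip is lossless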
@@ -125,16 +139,6 @@ function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
     )
 end

-function (proj::QKVProj)(q_in, k_in, v_in, num_heads)
-    q = proj.q_proj(q_in)
-    sz = size(q)
-    newsz = (sz[1] ÷ num_heads, num_heads, sz[2:end]...)
-    q = reshape(q, newsz)
-    k = reshape(proj.k_proj(k_in), newsz)
-    v = reshape(proj.v_proj(v_in), newsz)
-    return q, k, v
-end
-
 function (proj::QKVProj)(q_in, k_in, v_in)
     return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
 end
@@ -145,51 +149,73 @@ function perf(dim, len, batch_size, num_heads)
     x = rand(Float32, (dim, len, batch_size))

     println("tullio")
-    @btime $mha($x, v=:tullio);
-    @btime gradient(m -> sum(m($x, v=:tullio)), $mha);
+    @btime $mha($x, impl=:tullio);
+    @btime gradient(m -> sum(m($x, impl=:tullio)), $mha);

     println("nnalib")
-    @btime $mha($x, $x, $x, v=:nnalib);
-    @btime gradient(m -> sum(m($x, v=:nnalib)), $mha);
+    @btime $mha($x, $x, $x, impl=:nnalib);
+    @btime gradient(m -> sum(m($x, impl=:nnalib)), $mha);

     if CUDA.functional()
         mha_gpu = mha |> gpu
         x_gpu = x |> gpu

         println("tullio - gpu")
-        @btime $mha_gpu($x_gpu, v=:tullio);
-        @btime gradient(m -> sum(m($x_gpu, v=:tullio)), $mha_gpu);
+        @btime $mha_gpu($x_gpu, impl=:tullio);
+        @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);

         println("nnalib - gpu")
-        @btime CUDA.@sync $mha_gpu($x_gpu, v=:nnalib);
-        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, v=:nnalib)), $mha_gpu);
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnalib);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnalib)), $mha_gpu);
     end
     return nothing
 end

-function test(dim, len, batch_size, num_heads)
+function test(dim, num_heads, len, batch_size)
     mha = MultiHeadAttention(dim, num_heads)
     x = rand(Float32, (dim, len, batch_size))
-    y = mha(x, v=:tullio)
+    y, α = mha(x, impl=:tullio, with_weights=true)
     @test y isa Array{Float32, 3}
     @test size(y) == (dim, len, batch_size)
-    y2 = mha(x, v=:nnalib)
+    @test α isa Array{Float32, 4}
+    @test size(α) == (len, len, num_heads, batch_size)
+
+    y2, α2 = mha(x, impl=:nnalib, with_weights=true)
     @test size(y) == size(y2)
-    @test y2 ≈ y
+    @test y2 ≈ y atol=1e-1
+    @test size(α) == size(α2)
+    @test α2 ≈ α atol=1e-1

     if CUDA.functional()
         mha_gpu = mha |> gpu
         x_gpu = x |> gpu

-        y_gpu = mha_gpu(x_gpu, v=:tullio)
-        y_gpu2 = mha_gpu(x_gpu, v=:nnalib)
-        @test Array(y_gpu) ≈ Array(y_gpu2)
+        y_gpu = mha_gpu(x_gpu, impl=:tullio)
+        y_gpu2 = mha_gpu(x_gpu, impl=:nnalib)
+        @test Array(y_gpu) ≈ Array(y_gpu2) atol=1e-1
         @test Array(y_gpu) ≈ y
     end
     return nothing
 end


-test(12, 3, 2, 4)
-
-perf(64, 100, 32, 4)
+test(4, 2, 2, 1)
+
+perf(128, 8, 128, 32)
+# tullio
+# 5.862 ms (85 allocations: 6.75 MiB)
+# 14.291 ms (1046 allocations: 17.17 MiB)
+# nnalib
+# 6.331 ms (90 allocations: 7.75 MiB)
+# 16.186 ms (690 allocations: 16.17 MiB)
+# tullio - gpu
+# 141.365 μs (499 allocations: 22.81 KiB)
+# 804.018 μs (2228 allocations: 113.45 KiB)
+# nnalib - gpu
+# 163.487 μs (410 allocations: 18.02 KiB)
+# 673.463 μs (1521 allocations: 84.64 KiB)
+
+dim = 4; num_heads=2; len=2; batch_size=1
+mha = MultiHeadAttention(dim, num_heads)
+x = rand(Float32, (dim, len, batch_size))
+y, α = mha(x, impl=:tullio, with_weights=true)
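
To close, a hedged usage sketch (not part of the commit) for the keyword interface introduced in this diff: cross-attention with separate query and key/value sequences, returning the attention weights alongside the output. Shapes follow the conventions documented in the forward pass above.

    # Hypothetical example with dims = 4 and 2 heads.
    m  = MultiHeadAttention(4, 2)
    q  = rand(Float32, 4, 5, 1)     # (q_in_dim, q_len, batch_size)
    kv = rand(Float32, 4, 7, 1)     # (k_in_dim = v_in_dim, kv_len, batch_size)
    y, α = m(q, kv, kv; impl=:tullio, with_weights=true)
    size(y)   # (4, 5, 1)    = (out_dim, q_len, batch_size)
    size(α)   # (7, 5, 2, 1) = (kv_len, q_len, num_heads, batch_size)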
