Commit 5baf2f0

[ci skip] fix tullio impl
1 parent 742f2b5 commit 5baf2f0

2 files changed, +69 -38 lines


src/layers/attention.jl

Lines changed: 52 additions & 38 deletions
@@ -3,6 +3,7 @@ using CUDA
 using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
 using NeuralAttentionlib
 using BenchmarkTools
+using Flux: glorot_uniform
 CUDA.allowscalar(false)

 const A3{T} = AbstractArray{T, 3}
@@ -48,27 +49,42 @@ end

 function MultiHeadAttention(dims, num_heads::Int;
                             bias::Bool = false,
-                            # init = glorot_uniform, # TODO
+                            init = glorot_uniform,
                             attn_dropout_prob = 0.0,
-                            out_proj_dropout_prob = 0.0,
-                            self=false)
+                            out_proj_dropout_prob = 0.0)

     dims = mha_process_dims(dims)
-    @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
-    qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias)
+    @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads"
+    qkv_proj = QKVProj(dims; bias, init)
     attn_drop = Dropout(attn_dropout_prob)
-    out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob))
+    out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob))
     return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
 end

+# The following inputs are equivalent:
+# 8
+# 8 => 8 => 8
+# (8, 8, 8) => 8 => 8
+# 8 => (8, 8) => 8
+# (8, 8, 8) => (8, 8) => 8   # (q_in, k_in, v_in) => (qk, v) => out
 mha_process_dims(dims::Int) =
-    (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+    (; q_in = dims, k_in = dims, v_in = dims,
+       qk = dims, v = dims, out = dims)

 mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
-    (; q_in = in, k_in = in, v_in = in, qkv, out)
+    (; q_in = in, k_in = in, v_in = in,
+       qk = qkv, v = qkv, out)

 mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
-    (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
+    (; q_in = in[1], k_in = in[2], v_in = in[3],
+       qk = qkv, v = qkv, out)
+
+mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) =
+    (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out)
+
+mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) =
+    (; q_in = in, k_in = in, v_in = in, qk, v, out)
+

 # self-attention
 (m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
@@ -79,11 +95,13 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
     ## [v_in] = [v_in_dim, kv_len, batch_size]

     q, k, v = m.qkv_proj(q_in, k_in, v_in)
-    # [q] = [qkv_dim, q_len, batch_size]
-    # [k] = [v] = [qkv_dim, kv_len, batch_size]
+    # [q] = [qk_dim, q_len, batch_size]
+    # [k] = [qk_dim, kv_len, batch_size]
+    # [v] = [v_dim, kv_len, batch_size]
+
     if impl == :tullio
         x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop)
-    elseif impl == :nnalib
+    elseif impl == :nalib
         x, α = NeuralAttentionlib.multihead_qkv_attention(
             NeuralAttentionlib.score_returning,
             m.num_heads, q, k, v)
@@ -114,7 +132,9 @@ end
 reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
 flatten_heads(x) = reshape(x, :, size(x)[3:end]...)

-function dot_product_attention_weights(q, k; dropout=nothing)
+function dot_product_attention_weights(q::A4{T}, k::A4{T};
+                                       dropout=nothing) where T
+    q = q ./ T(size(q, 1))
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]
     α = softmax(α, dims=1)
@@ -123,20 +143,19 @@ end


 struct QKVProj
+    q_proj::Dense
     k_proj::Dense
     v_proj::Dense
-    q_proj::Dense
 end

 @functor QKVProj

-function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
-    q_in_dim, k_in_dim, v_in_dim = in_dim
+function QKVProj(dims; bias = false, init=glorot_uniform)
     return QKVProj(
-        Dense(k_in_dim => qkv_dim; bias),
-        Dense(v_in_dim => qkv_dim; bias),
-        Dense(q_in_dim => qkv_dim; bias)
-    )
+        Dense(dims.q_in => dims.qk; bias, init),
+        Dense(dims.k_in => dims.qk; bias, init),
+        Dense(dims.v_in => dims.v; bias, init)
+    )
 end

 function (proj::QKVProj)(q_in, k_in, v_in)
@@ -152,9 +171,9 @@ function perf(dim, len, batch_size, num_heads)
     @btime $mha($x, impl=:tullio);
     @btime gradient(m -> sum(m($x, impl=:tullio)), $mha);

-    println("nnalib")
-    @btime $mha($x, $x, $x, impl=:nnalib);
-    @btime gradient(m -> sum(m($x, impl=:nnalib)), $mha);
+    println("nalib")
+    @btime $mha($x, $x, $x, impl=:nalib);
+    @btime gradient(m -> sum(m($x, impl=:nalib)), $mha);

     if CUDA.functional()
         mha_gpu = mha |> gpu
@@ -164,9 +183,9 @@ function perf(dim, len, batch_size, num_heads)
         @btime $mha_gpu($x_gpu, impl=:tullio);
         @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu);

-        println("nnalib - gpu")
-        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnalib);
-        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnalib)), $mha_gpu);
+        println("nalib - gpu")
+        @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib);
+        @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu);
     end
     return nothing
 end
@@ -180,19 +199,19 @@ function test(dim, num_heads, len, batch_size)
     @test α isa Array{Float32, 4}
     @test size(α) == (len, len, num_heads, batch_size)

-    y2, α2 = mha(x, impl=:nnalib, with_weights=true)
+    y2, α2 = mha(x, impl=:nalib, with_weights=true)
     @test size(y) == size(y2)
-    @test y2 ≈ y atol=1e-1
+    @test y2 ≈ y
     @test size(α) == size(α2)
-    @test α2 ≈ α atol=1e-1
+    @test α2 ≈ α

     if CUDA.functional()
         mha_gpu = mha |> gpu
         x_gpu = x |> gpu

         y_gpu = mha_gpu(x_gpu, impl=:tullio)
-        y_gpu2 = mha_gpu(x_gpu, impl=:nnalib)
-        @test Array(y_gpu) ≈ Array(y_gpu2) atol=1e-1
+        y_gpu2 = mha_gpu(x_gpu, impl=:nalib)
+        @test Array(y_gpu) ≈ Array(y_gpu2)
         @test Array(y_gpu) ≈ y
     end
     return nothing
@@ -205,17 +224,12 @@ perf(128, 8, 128, 32)
 # tullio
 # 5.862 ms (85 allocations: 6.75 MiB)
 # 14.291 ms (1046 allocations: 17.17 MiB)
-# nnalib
+# nalib
 # 6.331 ms (90 allocations: 7.75 MiB)
 # 16.186 ms (690 allocations: 16.17 MiB)
 # tullio - gpu
 # 141.365 μs (499 allocations: 22.81 KiB)
 # 804.018 μs (2228 allocations: 113.45 KiB)
-# nnalib - gpu
+# nalib - gpu
 # 163.487 μs (410 allocations: 18.02 KiB)
 # 673.463 μs (1521 allocations: 84.64 KiB)
-
-dim = 4; num_heads=2; len=2; batch_size=1
-mha = MultiHeadAttention(dim, num_heads)
-x = rand(Float32, (dim, len, batch_size))
-y, α = mha(x, impl=:tullio, with_weights=true)
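
For orientation (not part of the commit), here is a minimal usage sketch of the new dims forms and of the Tullio path. The values are illustrative and the sketch assumes the code in src/layers/attention.jl above has been loaded:

# Usage sketch (illustrative values; assumes the file above is loaded).
dim = 8; num_heads = 2; len = 3; batch_size = 2
x = rand(Float32, dim, len, batch_size)

# Equivalent ways to specify the dimensions, per the comment in the diff:
m1 = MultiHeadAttention(dim, num_heads)                                    # 8
m2 = MultiHeadAttention(dim => dim => dim, num_heads)                      # 8 => 8 => 8
m3 = MultiHeadAttention((dim, dim, dim) => (dim, dim) => dim, num_heads)   # (q_in, k_in, v_in) => (qk, v) => out

y, α = m1(x, impl=:tullio, with_weights=true)
# size(y) == (dim, len, batch_size)             projected output
# size(α) == (len, len, num_heads, batch_size)  attention weights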

test_jax.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#%%
+import jax
+import jax.numpy as jnp                # JAX NumPy
+
+from flax import linen as nn           # The Linen API
+
+#import numpy as np                    # Ordinary NumPy
+#import optax                          # Optimizers
+#import tensorflow_datasets as tfds    # TFDS for MNIST
+# %%
+x = jnp.arange(16).reshape(1,2,2,4) / 16
+y = nn.dot_product_attention(x, x, x)
+yt = y.transpose((3,2,1,0))
+
+yt
+yt.shape
+# %%
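
For context (also not part of the commit), this script feeds a small deterministic input through Flax's dot_product_attention so the result can be eyeballed against the Julia/Tullio path. A rough Julia-side counterpart is sketched below; the reshape is an assumption about how JAX's (batch, len, heads, head_dim) layout maps onto the feature-major, column-major layout used above, and exact agreement is not expected, since this file scales q by 1/head_dim while Flax scales by 1/sqrt(head_dim):

# Hypothetical cross-check against test_jax.py (not in this commit).
num_heads = 2
x = reshape(Float32.(0:15) ./ 16, 8, 2, 1)            # same 16 values, (dim, len, batch) with dim = heads * head_dim
y, α = dot_product_attention(num_heads, x, x, x; dropout=nothing)
reshape(y, 4, 2, 2, 1)                                # head-split view, same layout as `yt` in the script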
