@@ -3,6 +3,7 @@ using CUDA
using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
using NeuralAttentionlib
using BenchmarkTools
+ using Flux: glorot_uniform
CUDA.allowscalar(false)

const A3{T} = AbstractArray{T, 3}

function MultiHeadAttention(dims, num_heads::Int;
            bias::Bool = false,
-           # init = glorot_uniform, # TODO
+           init = glorot_uniform,
            attn_dropout_prob = 0.0,
-           out_proj_dropout_prob = 0.0,
-           self = false)
+           out_proj_dropout_prob = 0.0)

    dims = mha_process_dims(dims)
-   @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
-   qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias)
+   @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads"
+   qkv_proj = QKVProj(dims; bias, init)
    attn_drop = Dropout(attn_dropout_prob)
-   out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob))
+   out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob))
    return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
end
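
# NOTE: illustrative sketch, not part of the diff. Constructing the layer with the
# new dims spec and init keyword; the expected output shape assumes the forward
# pass defined further below returns (out_dim, q_len, batch_size).
let
    mha = MultiHeadAttention(64 => (32, 48) => 64, 4; bias = true, init = glorot_uniform)
    x = rand(Float32, 64, 10, 3)        # (q_in_dim, q_len, batch_size)
    y = mha(x; impl = :tullio)          # self-attention call
    @assert size(y) == (64, 10, 3)
end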

+ # The following inputs are equivalent:
+ # 8
+ # 8 => 8 => 8
+ # (8, 8, 8) => 8 => 8
+ # 8 => (8, 8) => 8
+ # (8, 8, 8) => (8, 8) => 8   # (q_in, k_in, v_in) => (qk, v) => out
mha_process_dims(dims::Int) =
-   (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+   (; q_in = dims, k_in = dims, v_in = dims,
+      qk = dims, v = dims, out = dims)

mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
-   (; q_in = in, k_in = in, v_in = in, qkv, out)
+   (; q_in = in, k_in = in, v_in = in,
+      qk = qkv, v = qkv, out)

mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
-   (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
+   (; q_in = in[1], k_in = in[2], v_in = in[3],
+      qk = qkv, v = qkv, out)
+
+ mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) =
+   (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out)
+
+ mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) =
+   (; q_in = in, k_in = in, v_in = in, qk, v, out)
+
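
# NOTE: illustrative sketch, not part of the diff. Assuming the methods above, every
# spec listed in the "equivalent inputs" comment normalizes to the same named tuple.
let
    specs = (8, 8 => 8 => 8, (8, 8, 8) => 8 => 8, 8 => (8, 8) => 8, (8, 8, 8) => (8, 8) => 8)
    @assert all(==(mha_process_dims(8)), map(mha_process_dims, specs))
end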

# self-attention
(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
@@ -79,11 +95,13 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
    # [v_in] = [v_in_dim, kv_len, batch_size]

    q, k, v = m.qkv_proj(q_in, k_in, v_in)
-   # [q] = [qkv_dim, q_len, batch_size]
-   # [k] = [v] = [qkv_dim, kv_len, batch_size]
+   # [q] = [qk_dim, q_len, batch_size]
+   # [k] = [qk_dim, kv_len, batch_size]
+   # [v] = [v_dim, kv_len, batch_size]
+
    if impl == :tullio
        x, α = dot_product_attention(m.num_heads, q, k, v; dropout = m.attn_drop)
-   elseif impl == :nnalib
+   elseif impl == :nalib
        x, α = NeuralAttentionlib.multihead_qkv_attention(
            NeuralAttentionlib.score_returning,
            m.num_heads, q, k, v)

reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
flatten_heads(x) = reshape(x, :, size(x)[3:end]...)

- function dot_product_attention_weights(q, k; dropout = nothing)
+ function dot_product_attention_weights(q::A4{T}, k::A4{T};
+               dropout = nothing) where T
+   q = q ./ T(√size(q, 1))
    @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
    # [α] = [kv_len, q_len, num_heads, batch_size]
    α = softmax(α, dims = 1)
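
# NOTE: reference sketch, not part of the diff. For each head h and batch b, the
# @tullio line above computes α[:, :, h, b] = K'Q with Q = q[:, h, :, b] and
# K = k[:, h, :, b]; the 1/√d scaling is folded into q. A plain-loop equivalent
# taking unscaled q and k (the name attention_weights_ref is hypothetical):
function attention_weights_ref(q, k)
    d, nheads, qlen, batch = size(q)
    kvlen = size(k, 3)
    α = zeros(eltype(q), kvlen, qlen, nheads, batch)
    for b in 1:batch, h in 1:nheads
        α[:, :, h, b] = (transpose(k[:, h, :, b]) * q[:, h, :, b]) ./ √d
    end
    return softmax(α; dims = 1)
end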
@@ -123,20 +143,19 @@ end

struct QKVProj
+   q_proj::Dense
    k_proj::Dense
    v_proj::Dense
-   q_proj::Dense
end

@functor QKVProj

- function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
-   q_in_dim, k_in_dim, v_in_dim = in_dim
+ function QKVProj(dims; bias = false, init = glorot_uniform)
    return QKVProj(
-       Dense(k_in_dim => qkv_dim; bias),
-       Dense(v_in_dim => qkv_dim; bias),
-       Dense(q_in_dim => qkv_dim; bias)
-   )
+       Dense(dims.q_in => dims.qk; bias, init),
+       Dense(dims.k_in => dims.qk; bias, init),
+       Dense(dims.v_in => dims.v; bias, init)
+   )
end
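
# NOTE: illustrative sketch, not part of the diff. With asymmetric dims the three
# projections get different weight shapes (this assumes Flux's Dense stores its
# weight matrix as (out, in) in the `weight` field):
let
    dims = mha_process_dims((10, 12, 14) => (16, 20) => 8)
    proj = QKVProj(dims)
    @assert size(proj.q_proj.weight) == (16, 10)
    @assert size(proj.k_proj.weight) == (16, 12)
    @assert size(proj.v_proj.weight) == (20, 14)
end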

function (proj::QKVProj)(q_in, k_in, v_in)
@@ -152,9 +171,9 @@ function perf(dim, len, batch_size, num_heads)
    @btime $mha($x, impl = :tullio);
    @btime gradient(m -> sum(m($x, impl = :tullio)), $mha);

-   println("nnalib")
-   @btime $mha($x, $x, $x, impl = :nnalib);
-   @btime gradient(m -> sum(m($x, impl = :nnalib)), $mha);
+   println("nalib")
+   @btime $mha($x, $x, $x, impl = :nalib);
+   @btime gradient(m -> sum(m($x, impl = :nalib)), $mha);

    if CUDA.functional()
        mha_gpu = mha |> gpu
@@ -164,9 +183,9 @@ function perf(dim, len, batch_size, num_heads)
        @btime $mha_gpu($x_gpu, impl = :tullio);
        @btime gradient(m -> sum(m($x_gpu, impl = :tullio)), $mha_gpu);

-       println("nnalib - gpu")
-       @btime CUDA.@sync $mha_gpu($x_gpu, impl = :nnalib);
-       @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl = :nnalib)), $mha_gpu);
+       println("nalib - gpu")
+       @btime CUDA.@sync $mha_gpu($x_gpu, impl = :nalib);
+       @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl = :nalib)), $mha_gpu);
    end
    return nothing
end
@@ -180,19 +199,19 @@ function test(dim, num_heads, len, batch_size)
    @test α isa Array{Float32, 4}
    @test size(α) == (len, len, num_heads, batch_size)

-   y2, α2 = mha(x, impl = :nnalib, with_weights = true)
+   y2, α2 = mha(x, impl = :nalib, with_weights = true)
    @test size(y) == size(y2)
-   @test y2 ≈ y atol = 1e-1
+   @test y2 ≈ y
    @test size(α) == size(α2)
-   @test α2 ≈ α atol = 1e-1
+   @test α2 ≈ α

    if CUDA.functional()
        mha_gpu = mha |> gpu
        x_gpu = x |> gpu

        y_gpu = mha_gpu(x_gpu, impl = :tullio)
-       y_gpu2 = mha_gpu(x_gpu, impl = :nnalib)
-       @test Array(y_gpu) ≈ Array(y_gpu2) atol = 1e-1
+       y_gpu2 = mha_gpu(x_gpu, impl = :nalib)
+       @test Array(y_gpu) ≈ Array(y_gpu2)
        @test Array(y_gpu) ≈ y
    end
    return nothing
@@ -205,17 +224,12 @@ perf(128, 8, 128, 32)
# tullio
#   5.862 ms (85 allocations: 6.75 MiB)
#   14.291 ms (1046 allocations: 17.17 MiB)
- # nnalib
+ # nalib
#   6.331 ms (90 allocations: 7.75 MiB)
#   16.186 ms (690 allocations: 16.17 MiB)
# tullio - gpu
#   141.365 μs (499 allocations: 22.81 KiB)
#   804.018 μs (2228 allocations: 113.45 KiB)
- # nnalib - gpu
+ # nalib - gpu
#   163.487 μs (410 allocations: 18.02 KiB)
#   673.463 μs (1521 allocations: 84.64 KiB)
-
- dim = 4; num_heads = 2; len = 2; batch_size = 1
- mha = MultiHeadAttention(dim, num_heads)
- x = rand(Float32, (dim, len, batch_size))
- y, α = mha(x, impl = :tullio, with_weights = true)