using Flux, Functors, Test, LinearAlgebra, Random, Statistics
- using CUDA, CUDAKernels, KernelAbstractions, LoopVectorization
- using Tullio
+ using CUDA
+ using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
using NeuralAttentionlib
using BenchmarkTools
CUDA.allowscalar(false)  # error on scalar indexing of GPU arrays instead of silently falling back to slow CPU loops
+
const A3{T} = AbstractArray{T, 3}
+ const A4{T} = AbstractArray{T, 4}
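+ # A3 holds [features, len, batch] inputs; A4 holds per-head tensors [features ÷ num_heads, num_heads, len, batch]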

"""
    MultiHeadAttention(dims, num_heads;
@@ -48,7 +50,8 @@ function MultiHeadAttention(dims, num_heads::Int;
        bias::Bool = false,
        # init = glorot_uniform, # TODO
        attn_dropout_prob = 0.0,
-       out_proj_dropout_prob = 0.0)
+       out_proj_dropout_prob = 0.0,
+       self = false)

    dims = mha_process_dims(dims)
    @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads"
@@ -58,48 +61,59 @@ function MultiHeadAttention(dims, num_heads::Int;
    return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj)
end

- mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
- mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out)
- mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
+ mha_process_dims(dims::Int) =
+     (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims)
+
+ mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
+     (; q_in = in, k_in = in, v_in = in, qkv, out)
+
+ mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
+     (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out)
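+ # e.g. mha_process_dims(64)             == (; q_in = 64, k_in = 64, v_in = 64, qkv = 64, out = 64)
+ #      mha_process_dims(64 => 32 => 16) == (; q_in = 64, k_in = 64, v_in = 64, qkv = 32, out = 16)
+ #      mha_process_dims((4, 5, 6) => 32 => 16) sets separate q/k/v input dims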

# self-attention
(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)

- function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights = false, v = :tullio)
+ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights = false, impl = :tullio)
    ## [q_in] = [q_in_dim, q_len, batch_size]
    ## [k_in] = [k_in_dim, kv_len, batch_size]
    ## [v_in] = [v_in_dim, kv_len, batch_size]

-   if v == :tullio
-       q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads)
-       # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size]
-       # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size]
-
-       x, α = dot_product_attention(q, k, v; dropout = m.attn_drop)
-       x = reshape(x, :, size(x, 3), size(x, 4))
-   elseif v == :nnalib
-       q, k, v = m.qkv_proj(q_in, k_in, v_in)
-       x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v)
+   q, k, v = m.qkv_proj(q_in, k_in, v_in)
+   # [q] = [qkv_dim, q_len, batch_size]
+   # [k] = [v] = [qkv_dim, kv_len, batch_size]
+   if impl == :tullio
+       x, α = dot_product_attention(m.num_heads, q, k, v; dropout = m.attn_drop)
+   elseif impl == :nnalib
+       x, α = NeuralAttentionlib.multihead_qkv_attention(
+           NeuralAttentionlib.score_returning,
+           m.num_heads, q, k, v)
    else
        error("Unknown attention implementation")
    end
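+   # [x] = [qkv_dim, q_len, batch_size] from either branch
+   # (for impl == :tullio, [α] = [kv_len, q_len, num_heads, batch_size])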

    x = m.out_proj(x)

-   return x
-   # return with_weights ? (x, α) : x
+   return with_weights ? (x, α) : x
end

- # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention
- function dot_product_attention(q, k, v; dropout = nothing)
+ # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
+ function dot_product_attention(q::A4, k::A4, v::A4; dropout = nothing)
    α = dot_product_attention_weights(q, k; dropout)
    # [α] = [kv_len, q_len, num_heads, batch_size]
    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    # [x] = [qkv_dim ÷ num_heads, num_heads, q_len, batch_size]
-
    return x, α
end
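+ # the @tullio contraction above is a per-head weighted sum of values:
+ # x[:, h, i, b] = Σ_j α[j, i, h, b] * v[:, h, j, b]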

+ function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
+     q, k, v = reshape_heads.((q, k, v), num_heads)
+     x, α = dot_product_attention(q, k, v; kws...)
+     return flatten_heads(x), α
+ end
+
+ reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+ flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
+
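+ # reshape_heads: [dim, len, batch] -> [dim ÷ num_heads, num_heads, len, batch]
+ # flatten_heads: [dim ÷ num_heads, num_heads, len, batch] -> [dim, len, batch]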

function dot_product_attention_weights(q, k; dropout = nothing)
    @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
    # [α] = [kv_len, q_len, num_heads, batch_size]
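    # i.e. α[j, i, h, b] = dot(q[:, h, i, b], k[:, h, j, b]), the query-key score for each head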
@@ -125,16 +139,6 @@ function QKVProj((in_dim, qkv_dim)::Pair; bias = false)
    )
end

- function (proj::QKVProj)(q_in, k_in, v_in, num_heads)
-     q = proj.q_proj(q_in)
-     sz = size(q)
-     newsz = (sz[1] ÷ num_heads, num_heads, sz[2:end]...)
-     q = reshape(q, newsz)
-     k = reshape(proj.k_proj(k_in), newsz)
-     v = reshape(proj.v_proj(v_in), newsz)
-     return q, k, v
- end
-
function (proj::QKVProj)(q_in, k_in, v_in)
    return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
end
@@ -145,51 +149,73 @@ function perf(dim, len, batch_size, num_heads)
    x = rand(Float32, (dim, len, batch_size))

    println("tullio")
-   @btime $mha($x, v = :tullio);
-   @btime gradient(m -> sum(m($x, v = :tullio)), $mha);
+   @btime $mha($x, impl = :tullio);
+   @btime gradient(m -> sum(m($x, impl = :tullio)), $mha);

    println("nnalib")
-   @btime $mha($x, $x, $x, v = :nnalib);
-   @btime gradient(m -> sum(m($x, v = :nnalib)), $mha);
+   @btime $mha($x, $x, $x, impl = :nnalib);
+   @btime gradient(m -> sum(m($x, impl = :nnalib)), $mha);

    if CUDA.functional()
        mha_gpu = mha |> gpu
        x_gpu = x |> gpu

        println("tullio - gpu")
-       @btime $mha_gpu($x_gpu, v = :tullio);
-       @btime gradient(m -> sum(m($x_gpu, v = :tullio)), $mha_gpu);
+       @btime $mha_gpu($x_gpu, impl = :tullio);
+       @btime gradient(m -> sum(m($x_gpu, impl = :tullio)), $mha_gpu);

        println("nnalib - gpu")
-       @btime CUDA.@sync $mha_gpu($x_gpu, v = :nnalib);
-       @btime CUDA.@sync gradient(m -> sum(m($x_gpu, v = :nnalib)), $mha_gpu);
+       @btime CUDA.@sync $mha_gpu($x_gpu, impl = :nnalib);
+       @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl = :nnalib)), $mha_gpu);
    end
    return nothing
end

- function test(dim, len, batch_size, num_heads)
+ function test(dim, num_heads, len, batch_size)
    mha = MultiHeadAttention(dim, num_heads)
    x = rand(Float32, (dim, len, batch_size))
-   y = mha(x, v = :tullio)
+   y, α = mha(x, impl = :tullio, with_weights = true)
    @test y isa Array{Float32, 3}
    @test size(y) == (dim, len, batch_size)
-   y2 = mha(x, v = :nnalib)
+   @test α isa Array{Float32, 4}
+   @test size(α) == (len, len, num_heads, batch_size)
+
+   y2, α2 = mha(x, impl = :nnalib, with_weights = true)
    @test size(y) == size(y2)
-   @test y2 ≈ y
+   @test y2 ≈ y atol = 1e-1
+   @test size(α) == size(α2)
+   @test α2 ≈ α atol = 1e-1

    if CUDA.functional()
        mha_gpu = mha |> gpu
        x_gpu = x |> gpu

-       y_gpu = mha_gpu(x_gpu, v = :tullio)
-       y_gpu2 = mha_gpu(x_gpu, v = :nnalib)
-       @test Array(y_gpu) ≈ Array(y_gpu2)
+       y_gpu = mha_gpu(x_gpu, impl = :tullio)
+       y_gpu2 = mha_gpu(x_gpu, impl = :nnalib)
+       @test Array(y_gpu) ≈ Array(y_gpu2) atol = 1e-1
        @test Array(y_gpu) ≈ y
    end
    return nothing
end

- test(12, 3, 2, 4)
-
- perf(64, 100, 32, 4)
+ test(4, 2, 2, 1)
+
+ perf(128, 8, 128, 32)
+ # timings below: first line of each pair is the forward pass, second the gradient
+ # tullio
+ # 5.862 ms (85 allocations: 6.75 MiB)
+ # 14.291 ms (1046 allocations: 17.17 MiB)
+ # nnalib
+ # 6.331 ms (90 allocations: 7.75 MiB)
+ # 16.186 ms (690 allocations: 16.17 MiB)
+ # tullio - gpu
+ # 141.365 μs (499 allocations: 22.81 KiB)
+ # 804.018 μs (2228 allocations: 113.45 KiB)
+ # nnalib - gpu
+ # 163.487 μs (410 allocations: 18.02 KiB)
+ # 673.463 μs (1521 allocations: 84.64 KiB)
+
+ dim = 4; num_heads = 2; len = 2; batch_size = 1
+ mha = MultiHeadAttention(dim, num_heads)
+ x = rand(Float32, (dim, len, batch_size))
+ y, α = mha(x, impl = :tullio, with_weights = true)
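+ # expected: size(y) == (4, 2, 1) and size(α) == (2, 2, 2, 1), matching the assertions in test()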