@@ -146,7 +146,7 @@
 function dot_product_attention_weights(q::A4{T}, k::A4{T};
         dropout=nothing, mask=nothing, bias=nothing) where T
 
-    q = q ./ T(√size(q, 1))
+    q = q ./ √T(size(q, 1))
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]
 
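For reference, the `@tullio` contraction above computes the pre-softmax scores α[j, i, h, b] = Σ_d q[d, h, i, b] * k[d, h, j, b]. A minimal loop-level sketch of the same contraction, assuming the [head_dim, num_heads, len, batch] layout used above (the function name `attention_scores_naive` is illustrative, not part of this commit):

function attention_scores_naive(q::AbstractArray{T,4}, k::AbstractArray{T,4}) where T
    d, h, q_len, b = size(q)
    kv_len = size(k, 3)
    α = zeros(T, kv_len, q_len, h, b)
    for bi in 1:b, hi in 1:h, i in 1:q_len, j in 1:kv_len
        s = zero(T)
        for di in 1:d
            s += q[di, hi, i, bi] * k[di, hi, j, bi]   # dot product over the head dimension
        end
        α[j, i, hi, bi] = s   # [α] = [kv_len, q_len, num_heads, batch_size]
    end
    return α
end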
@@ -173,10 +173,9 @@
 
 function QKVProj(dims; bias = false, init=glorot_uniform)
     return QKVProj(
-        Dense(dims.q_in => dims.qk; bias, init),
-        Dense(dims.k_in => dims.qk; bias, init),
-        Dense(dims.v_in => dims.v; bias, init)
-    )
+        Dense(dims.q_in => dims.qk; bias, init),
+        Dense(dims.k_in => dims.qk; bias, init),
+        Dense(dims.v_in => dims.v; bias, init))
 end
 
 function (proj::QKVProj)(q_in, k_in, v_in)
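A hypothetical construction sketch for the projection above; the field names q_in/k_in/v_in/qk/v come from the constructor, but the concrete sizes are made up:

dims = (q_in = 64, k_in = 64, v_in = 64, qk = 64, v = 64)  # hypothetical sizes
proj = QKVProj(dims; bias = false)                         # three Dense layers: q, k, v projections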
@@ -224,32 +223,35 @@ function perf(dim, len, batch_size, num_heads)
 end
 
 function test(dim, num_heads, len, batch_size)
-    mha = MultiHeadAttention(dim, num_heads)
-    x = rand(Float32, (dim, len, batch_size))
-    y, α = mha(x, impl=:tullio, with_weights=true)
+    mha = MultiHeadAttention(dim, num_heads)
+    q = rand(Float32, (dim, len, batch_size))
+    k = rand(Float32, (dim, len, batch_size))
+    v = rand(Float32, (dim, len, batch_size))
+
+    y, α = mha(q, k, v, impl=:tullio, with_weights=true)
     @test y isa Array{Float32, 3}
     @test size(y) == (dim, len, batch_size)
     @test α isa Array{Float32, 4}
     @test size(α) == (len, len, num_heads, batch_size)
 
-    y2, α2 = mha(x, impl=:nalib, with_weights=true)
+    y2, α2 = mha(q, k, v, impl=:nalib, with_weights=true)
     @test size(y) == size(y2)
     @test y2 ≈ y
     @test size(α) == size(α2)
     @test α2 ≈ α
 
-    mask = make_causal_mask(x)
-    y3, α3 = mha(x; impl=:tullio, with_weights=true, mask)
-    y4, α4 = mha(x, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
-    @test y ≈ y2
-    @test α ≈ α2
+    mask = make_causal_mask(q)
+    y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask)
+    y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
+    @test y3 ≈ y4
+    @test α3 ≈ α4
 
     if CUDA.functional()
         mha_gpu = mha |> gpu
-        x_gpu = x |> gpu
-
-        y_gpu = mha_gpu(x_gpu, impl=:tullio)
-        y_gpu2 = mha_gpu(x_gpu, impl=:nalib)
+        q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu
+
+        y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio)
+        y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib)
         @test Array(y_gpu) ≈ Array(y_gpu2)
         @test Array(y_gpu) ≈ y
     end
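With the explicit q/k/v signature exercised in the tests above, plain self-attention can still be written by passing the same array three times; a usage sketch with arbitrary sizes:

dim, len, batch_size, num_heads = 64, 10, 32, 4
mha = MultiHeadAttention(dim, num_heads)
x = rand(Float32, (dim, len, batch_size))
y, α = mha(x, x, x, impl=:tullio, with_weights=true)  # self-attention: q = k = v = x
size(y) == (dim, len, batch_size)                     # output keeps the input layout
size(α) == (len, len, num_heads, batch_size)          # per-head attention weights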