
Commit e212b6b

[ci skip] updates
Parent: 6e7f538


src/layers/attention.jl

Lines changed: 20 additions & 18 deletions
@@ -146,7 +146,7 @@ end
 function dot_product_attention_weights(q::A4{T}, k::A4{T};
         dropout=nothing, mask=nothing, bias=nothing) where T

-    q = q ./ T(size(q, 1))
+    q = q ./ T(size(q, 1))
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]

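Note: the `@tullio` line above is an einsum-style contraction that sums over the feature index `d`. For reference, a plain-loop sketch of the same computation; the shapes are inferred from the index order used here (`q`, `k` as `(head_dim, num_heads, len, batch)`, `α` as `(kv_len, q_len, num_heads, batch)`), and this helper is illustrative only, not part of the diff:

# Illustrative only: a plain-loop version of the @tullio contraction.
# Assumes q and k are (head_dim, num_heads, len, batch) and returns
# α as (kv_len, q_len, num_heads, batch), matching the comment above.
function attention_scores_loops(q, k)
    d, h, qlen, b = size(q)
    klen = size(k, 3)
    α = zeros(eltype(q), klen, qlen, h, b)
    for bi in 1:b, hi in 1:h, i in 1:qlen, j in 1:klen
        s = zero(eltype(q))
        for di in 1:d
            s += q[di, hi, i, bi] * k[di, hi, j, bi]
        end
        α[j, i, hi, bi] = s
    end
    return α
end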
@@ -173,10 +173,9 @@ end

 function QKVProj(dims; bias = false, init=glorot_uniform)
     return QKVProj(
-        Dense(dims.q_in => dims.qk; bias, init),
-        Dense(dims.k_in => dims.qk; bias, init),
-        Dense(dims.v_in => dims.v; bias, init)
-    )
+        Dense(dims.q_in => dims.qk; bias, init),
+        Dense(dims.k_in => dims.qk; bias, init),
+        Dense(dims.v_in => dims.v; bias, init))
 end

 function (proj::QKVProj)(q_in, k_in, v_in)
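For context, the struct definition itself sits outside this hunk; a minimal sketch of the assumed surrounding code (the field names and the call behavior are assumptions, not shown in this diff):

# Assumed context: QKVProj holds one Dense projection each for
# queries, keys, and values.
using Flux: Dense

struct QKVProj
    q_proj::Dense
    k_proj::Dense
    v_proj::Dense
end

# Apply the three projections to their respective inputs.
(proj::QKVProj)(q_in, k_in, v_in) =
    (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))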
@@ -224,32 +223,35 @@ function perf(dim, len, batch_size, num_heads)
 end

 function test(dim, num_heads, len, batch_size)
-    mha = MultiHeadAttention(dim, num_heads)
-    x = rand(Float32, (dim, len, batch_size))
-    y, α = mha(x, impl=:tullio, with_weights=true)
+    mha = MultiHeadAttention(dim, num_heads)
+    q = rand(Float32, (dim, len, batch_size))
+    k = rand(Float32, (dim, len, batch_size))
+    v = rand(Float32, (dim, len, batch_size))
+
+    y, α = mha(q, k, v, impl=:tullio, with_weights=true)
     @test y isa Array{Float32, 3}
     @test size(y) == (dim, len, batch_size)
     @test α isa Array{Float32, 4}
     @test size(α) == (len, len, num_heads, batch_size)

-    y2, α2 = mha(x, impl=:nalib, with_weights=true)
+    y2, α2 = mha(q, k, v, impl=:nalib, with_weights=true)
     @test size(y) == size(y2)
     @test y2 ≈ y
     @test size(α) == size(α2)
     @test α2 ≈ α

-    mask = make_causal_mask(x)
-    y3, α3 = mha(x; impl=:tullio, with_weights=true, mask)
-    y4, α4 = mha(x, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
-    @test y ≈ y2
-    @test α ≈ α2
+    mask = make_causal_mask(q)
+    y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask)
+    y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
+    @test y3 ≈ y4
+    @test α3 ≈ α4

     if CUDA.functional()
         mha_gpu = mha |> gpu
-        x_gpu = x |> gpu
-
-        y_gpu = mha_gpu(x_gpu, impl=:tullio)
-        y_gpu2 = mha_gpu(x_gpu, impl=:nalib)
+        q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu
+
+        y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio)
+        y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib)
         @test Array(y_gpu) ≈ Array(y_gpu2)
         @test Array(y_gpu) ≈ y
     end
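
Taken together, the test changes move from single-input self-attention (`mha(x)`) to an explicit three-argument `mha(q, k, v)` form. A usage sketch of the API these tests exercise; the dimension values are arbitrary examples and the names assume the definitions in this file are in scope:

# Hypothetical usage of the three-argument API exercised by the tests above.
dim, num_heads, len, batch_size = 64, 4, 10, 32
mha = MultiHeadAttention(dim, num_heads)

q = rand(Float32, dim, len, batch_size)
k = rand(Float32, dim, len, batch_size)
v = rand(Float32, dim, len, batch_size)

# Output plus attention weights; self-attention corresponds to q == k == v.
y, α = mha(q, k, v; impl=:tullio, with_weights=true)
# size(y) == (dim, len, batch_size)
# size(α) == (len, len, num_heads, batch_size)

# Causal masking, compared against NeuralAttentionlib.CausalMask() above.
mask = make_causal_mask(q)
y_masked, α_masked = mha(q, k, v; impl=:tullio, with_weights=true, mask)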
