Skip to content

Commit 2be464f

Browse files
[ci skip] mask
1 parent 52bee5e commit 2be464f

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

src/layers/attention.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using NeuralAttentionlib: score_returning
66
using BenchmarkTools
77
using Flux: glorot_uniform
88
using MLUtils
9+
using ChainRulesCore
910
CUDA.allowscalar(false)
1011

1112
const A3{T} = AbstractArray{T, 3}
@@ -112,7 +113,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
112113
if impl == :tullio
113114
x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
114115
elseif impl == :nalib
115-
x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v)
116+
x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask)
116117
else
117118
error("Unknown attention implementation")
118119
end
@@ -184,11 +185,16 @@ end
184185

185186
function make_causal_mask(x::A3)
186187
d, len, batch_size = size(x)
187-
mask = tril(ones_like(x, (len, len)))
188+
mask = triu(trues_like(x, (len, len)))
188189
return mask
189190
end
190191

192+
trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
193+
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)
194+
191195
@non_differentiable make_causal_mask(x)
196+
@non_differentiable trues_like(::Any...)
197+
@non_differentiable falses_like(::Any...)
192198

193199
function perf(dim, len, batch_size, num_heads)
194200
mha = MultiHeadAttention(dim, num_heads)
@@ -231,7 +237,13 @@ function test(dim, num_heads, len, batch_size)
231237
@test y2 ≈ y
232238
@test size(α) == size(α2)
233239
@test α2 ≈ α
234-
240+
241+
mask = make_causal_mask(x)
242+
y3, α3 = mha(x; impl=:tullio, with_weights=true, mask)
243+
y4, α4 = mha(x, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())
244+
@test y ≈ y2
245+
@test α ≈ α2
246+
235247
if CUDA.functional()
236248
mha_gpu = mha |> gpu
237249
x_gpu = x |> gpu
@@ -244,8 +256,7 @@ function test(dim, num_heads, len, batch_size)
244256
return nothing
245257
end
246258

247-
248-
test(4, 2, 2, 1)
259+
test(4, 2, 3, 1)
249260

250261
perf(128, 8, 128, 32)
251262
# tullio
@@ -267,3 +278,9 @@ perf(128, 8, 128, 32)
267278
# 165.109 μs (411 allocations: 18.05 KiB)
268279
# 659.685 μs (1527 allocations: 86.09 KiB)
269280

281+
dim = 2; len = 3; batch_size = 1; num_heads = 1
282+
mha = MultiHeadAttention(dim, num_heads)
283+
x = rand(Float32, (dim, len, batch_size))
284+
mask = make_causal_mask(x)
285+
y, α = mha(x; impl=:tullio, with_weights=true, mask)
286+
y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask())

0 commit comments

Comments
 (0)