@@ -6,6 +6,7 @@ using NeuralAttentionlib: score_returning
6
6
using BenchmarkTools
7
7
using Flux: glorot_uniform
8
8
using MLUtils
9
+ using ChainRulesCore
9
10
CUDA. allowscalar (false )
10
11
11
12
const A3{T} = AbstractArray{T, 3 }
@@ -112,7 +113,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
112
113
if impl == :tullio
113
114
x, α = dot_product_attention (m. num_heads, q, k, v; mask, dropout= m. attn_drop)
114
115
elseif impl == :nalib
115
- x, α = NeuralAttentionlib. multihead_qkv_attention (score_returning, m. num_heads, q, k, v)
116
+ x, α = NeuralAttentionlib. multihead_qkv_attention (score_returning, m. num_heads, q, k, v, mask )
116
117
else
117
118
error (" Unknown attention implementation" )
118
119
end
@@ -184,11 +185,16 @@ end
184
185
185
186
function make_causal_mask (x:: A3 )
186
187
d, len, batch_size = size (x)
187
- mask = tril ( ones_like (x, (len, len)))
188
+ mask = triu ( trues_like (x, (len, len)))
188
189
return mask
189
190
end
190
191
192
+ trues_like (x:: AbstractArray , sz= size (x)) = fill! (similar (x, Bool, sz), true )
193
+ falses_like (x:: AbstractArray , sz= size (x)) = fill! (similar (x, Bool, sz), false )
194
+
191
195
@non_differentiable make_causal_mask (x)
196
+ @non_differentiable trues_like (:: Any... )
197
+ @non_differentiable falses_like (:: Any... )
192
198
193
199
function perf (dim, len, batch_size, num_heads)
194
200
mha = MultiHeadAttention (dim, num_heads)
@@ -231,7 +237,13 @@ function test(dim, num_heads, len, batch_size)
231
237
@test y2 ≈ y
232
238
@test size (α) == size (α2)
233
239
@test α2 ≈ α
234
-
240
+
241
+ mask = make_causal_mask (x)
242
+ y3, α3 = mha (x; impl= :tullio , with_weights= true , mask)
243
+ y4, α4 = mha (x, impl= :nalib , with_weights= true , mask= NeuralAttentionlib. CausalMask ())
244
+ @test y ≈ y2
245
+ @test α ≈ α2
246
+
235
247
if CUDA. functional ()
236
248
mha_gpu = mha |> gpu
237
249
x_gpu = x |> gpu
@@ -244,8 +256,7 @@ function test(dim, num_heads, len, batch_size)
244
256
return nothing
245
257
end
246
258
247
-
248
- test (4 , 2 , 2 , 1 )
259
+ test (4 , 2 , 3 , 1 )
249
260
250
261
perf (128 , 8 , 128 , 32 )
251
262
# tullio
@@ -267,3 +278,9 @@ perf(128, 8, 128, 32)
267
278
# 165.109 μs (411 allocations: 18.05 KiB)
268
279
# 659.685 μs (1527 allocations: 86.09 KiB)
269
280
281
+ dim = 2 ; len = 3 ; batch_size = 1 ; num_heads = 1
282
+ mha = MultiHeadAttention (dim, num_heads)
283
+ x = rand (Float32, (dim, len, batch_size))
284
+ mask = make_causal_mask (x)
285
+ y, α = mha (x; impl= :tullio , with_weights= true , mask)
286
+ y2, α2 = mha (x; impl= :nalib , with_weights= true , mask= NeuralAttentionlib. CausalMask ())
0 commit comments