Skip to content

Commit aac281d

Browse files
fix tests
1 parent eabcc02 commit aac281d

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

src/attention.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T};
9797
end
9898

9999
"""
100-
make_causal_mask(x)
100+
make_causal_mask(x, dims=2)
101101
102-
Return a boolean square matrix `m` of the same type as `x` and of side `size(x,2)`.
102+
Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
103103
Its elements are set such that `m[i, j] == i ≤ j`.
104104
105105
Can be used to mask the attention scores in [`dot_product_attention`](@ref).
106106
"""
107-
function make_causal_mask(x::AbstractArray)
108-
len = size(x, 2)
107+
function make_causal_mask(x::AbstractArray; dims::Int=2)
108+
len = size(x, dims)
109109
mask = triu(trues_like(x, (len, len)))
110110
return mask
111111
end

test/attention.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,57 @@
22
n = 15
33
lenq = 3
44
lenkv = 4
5-
for batch_size in [(), 1, 2, (2,1,3)], num_heads in [1, 3, 5]
5+
for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]
66
q = rand(Float32, n, lenq, batch_size...)
77
k = rand(Float32, n, lenkv, batch_size...)
88
v = rand(Float32, n, lenkv, batch_size...)
9-
y, α = dot_product_attention(q, k, v; num_heads)
9+
y, α = dot_product_attention(q, k, v; nheads)
1010
@test y isa Array{Float32}
1111
@test size(y) == (n, lenq, batch_size...)
12-
@test size(α) == (lenkv, lenq, num_heads, batch_size...)
13-
@test sum(α, dims=1) ≈ ones(1, lenq, num_heads, batch_size...)
12+
@test size(α) == (lenkv, lenq, nheads, batch_size...)
13+
@test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
1414
end
1515
end
1616

1717
@testset "dot_product_attention_scores" begin
1818
q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
1919
α = dot_product_attention_scores(q, k)
2020
q2, k2 = reshape.((q, k), 8, 3, 1)
21-
y, α2 = dot_product_attention(q2, k2, k2; num_heads=2)
21+
y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
2222
@test α ≈ α2
2323
end
2424

2525
@testset "specific results" begin
2626
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
27-
y, α = dot_product_attention(q, k, v; num_heads=2)
27+
y, α = dot_product_attention(q, k, v; nheads=2)
2828
@test y ≈ [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;]
2929
@test α ≈ [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;]
3030
end
3131

3232
@testset "mask" begin
3333
q = rand(4, 2, 3, 1)
3434
k = rand(4, 2, 5, 1)
35+
3536
mask = rand(Bool, (5, 3))
3637
α = dot_product_attention_scores(q, k; mask)
3738
@test all((α[:,:,1,1].> 0) .== mask)
3839
@test all((α[:,:,2,1].> 0) .== mask)
40+
41+
@testset "causal" begin
42+
x = rand(4, 2, 3, 1)
43+
mask = make_causal_mask(x, dims=3)
44+
α = dot_product_attention_scores(x, x; mask)
45+
@test all((α[:,:,1,1].> 0) .== mask)
46+
@test all((α[:,:,2,1].> 0) .== mask)
47+
48+
α2 = dot_product_attention_scores(x, x; mask=:causal)
49+
@test α2 ≈ α
50+
end
51+
end
52+
53+
@testset "dropout" begin
54+
q = k = v = rand(10, 10, 10)
55+
fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
56+
y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
57+
@test 0.6 > mean(>(0), α) > 0.4
3958
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ include("test_utils.jl")
4040
end
4141

4242
@testset "Attention" begin
43-
include("activations.jl")
43+
include("attention.jl")
4444
end
4545

4646
@testset "Batched Multiplication" begin

0 commit comments

Comments
 (0)