|
# Shape tests: dot_product_attention over every combination of batch shape
# (none, scalar, and a 3-d tuple batch) and head count.
n = 15
lenq = 3
lenkv = 4

for batch_size in [(), 1, 2, (2,1,3)]
    for nheads in [1, 3, 5]
        q = rand(Float32, n, lenq, batch_size...)
        k = rand(Float32, n, lenkv, batch_size...)
        v = rand(Float32, n, lenkv, batch_size...)

        y, α = dot_product_attention(q, k, v; nheads)

        # Output keeps the query layout and eltype.
        @test y isa Array{Float32}
        @test size(y) == (n, lenq, batch_size...)
        # Attention weights are (keys, queries, heads, batch...).
        @test size(α) == (lenkv, lenq, nheads, batch_size...)
        # Softmax normalizes over the key dimension.
        @test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
    end
end
|
15 | 15 | end
|
16 | 16 |
|
@testset "dot_product_attention_scores" begin
    # Scores computed on an already head-split layout must agree with the
    # weights returned by the full attention call on the flat layout.
    q4 = k4 = reshape([1:24;], 4, 2, 3, 1) ./ 24
    α = dot_product_attention_scores(q4, k4)

    qflat = reshape(q4, 8, 3, 1)
    kflat = reshape(k4, 8, 3, 1)
    _, α2 = dot_product_attention(qflat, kflat, kflat; nheads=2)
    @test α ≈ α2
end
|
24 | 24 |
|
@testset "specific results" begin
    # Deterministic input so output and weights can be pinned to exact values.
    x = reshape([1:12;], 4, 3, 1) ./ 12
    y, α = dot_product_attention(x, x, x; nheads=2)
    @test y ≈ [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;]
    @test α ≈ [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;]
end
|
31 | 31 |
|
@testset "mask" begin
    q = rand(4, 2, 3, 1)
    k = rand(4, 2, 5, 1)

    # Masked-out positions must receive exactly zero attention weight,
    # for every head.
    mask = rand(Bool, (5, 3))
    α = dot_product_attention_scores(q, k; mask)
    for h in (1, 2)
        @test all((α[:, :, h, 1] .> 0) .== mask)
    end

    @testset "causal" begin
        x = rand(4, 2, 3, 1)
        cmask = make_causal_mask(x, dims=3)
        α = dot_product_attention_scores(x, x; mask=cmask)
        for h in (1, 2)
            @test all((α[:, :, h, 1] .> 0) .== cmask)
        end

        # The :causal shorthand must agree with an explicit causal mask.
        @test dot_product_attention_scores(x, x; mask=:causal) ≈ α
    end
end
| 52 | + |
@testset "dropout" begin
    # A custom fdrop that zeroes roughly half the scores should leave
    # roughly half of the attention weights strictly positive.
    x = rand(10, 10, 10)
    halfdrop(s) = (rand!(similar(s)) .> 0.5) .* s ./ (1 - 0.5)
    y, α = dot_product_attention(x, x, x; nheads=2, fdrop=halfdrop)
    @test 0.4 < mean(>(0), α) < 0.6
end
|
0 commit comments