Skip to content

Commit 739c14f

Browse files
fix tests on julia 1.6
1 parent 0139d42 commit 739c14f

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

test/attention.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ end
2525
@testset "specific results" begin
2626
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
2727
y, α = dot_product_attention(q, k, v; nheads=2)
28-
@test y [0.4297536645089624 0.46431026790247376 0.49773020657887745; 0.5130869978422957 0.5476436012358071 0.5810635399122107; 0.6137914555895531 0.6478764227436047 0.6804545876711346; 0.6971247889228864 0.731209756076938 0.763787921004468;;;]
29-
@test α [0.3138955704910261 0.264431440679808 0.21921458153690657; 0.3329478654910607 0.32820631493296265 0.31838021718955445; 0.35315656401791323 0.4073622443872293 0.4624052012735389;;; 0.2886914482847165 0.24123865285082136 0.19843756756539277; 0.33124273666190807 0.3238934260675431 0.31176110185581074; 0.3800658150533755 0.43486792108163547 0.4898013305787966;;;;]
28+
ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468]
29+
ytrue = reshape(ytrue, 4, 3, 1)
30+
αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966]
31+
αtrue = reshapetrue, 3, 3, 2, 1)
32+
@test y ytrue
33+
@test α αtrue
3034
end
3135

3236
@testset "mask" begin

0 commit comments

Comments
 (0)