Skip to content

Commit 562c20a

Browse files
committed
Support ArrayInterface.can_avx and speed up tests on Github Actions
1 parent fded771 commit 562c20a

File tree

4 files changed

+46
-36
lines changed

4 files changed

+46
-36
lines changed

src/costs.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,13 +458,16 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
458458
typeof(<<) => :<<,
459459
typeof(>>) => :>>,
460460
typeof(>>>) => :>>>,
461+
typeof(%) => :(%),
462+
typeof(÷) => :(÷),
461463
typeof(Base.ifelse) => :ifelse,
462464
typeof(ifelse) => :ifelse,
463465
typeof(identity) => :identity,
464466
typeof(conj) => :conj
465467
)
466468

467469
# implement whitelist for avx_support that package authors may use to conservatively guard `@avx` application
468-
# for f ∈ keys(FUNCTIONSYMBOLS)
469-
# @eval ArrayInterface.can_avx(::$(typeof(f))) = true
470-
# end
470+
for f keys(FUNCTIONSYMBOLS)
471+
@eval ArrayInterface.can_avx(::$f) = true
472+
end
473+

test/broadcast.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,20 @@
8383
D1 = C .+ A * B;
8484
D2 = @avx C .+ A .*ˡ B;
8585
@test D1 D2
86-
fill!(D2, -999999); D2 = @avx C .+ At' *ˡ B;
87-
@test D1 D2
88-
fill!(D2, -999999); @test A * B (@avx @. D2 = A *ˡ B)
89-
D1 .= view(C, 1, :)' .+ A * B;
90-
fill!(D2, -999999);
91-
@avx D2 .= view(C, 1, :)' .+ A .*ˡ B;
92-
@test D1 D2
93-
C3d = rand(R,3,M,N);
94-
D1 .= view(C3d, 1, :, :) .+ A * B;
95-
fill!(D2, -999999);
96-
@avx D2 .= view(C3d, 1, :, :) .+ A .*ˡ B;
97-
@test D1 D2
98-
86+
if RUN_SLOW_TESTS
87+
fill!(D2, -999999); D2 = @avx C .+ At' *ˡ B;
88+
@test D1 D2
89+
fill!(D2, -999999); @test A * B (@avx @. D2 = A *ˡ B)
90+
D1 .= view(C, 1, :)' .+ A * B;
91+
fill!(D2, -999999);
92+
@avx D2 .= view(C, 1, :)' .+ A .*ˡ B;
93+
@test D1 D2
94+
C3d = rand(R,3,M,N);
95+
D1 .= view(C3d, 1, :, :) .+ A * B;
96+
fill!(D2, -999999);
97+
@avx D2 .= view(C3d, 1, :, :) .+ A .*ˡ B;
98+
@test D1 D2
99+
end
99100
D1 .= 9999;
100101
@avx D2 .= 9999;
101102
@test D1 == D2

test/gemm.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -710,14 +710,21 @@
710710
@test C C2
711711
fill!(C, 9999.999); mulCAtB_2x2blockavx_noinline!(C, A', B);
712712
@test C C2
713-
fill!(C, 9999.999); gemm_accurate!(C, A, B);
714-
@test C C2
715-
fill!(C, 9999.999); gemm_accurate!(C, At', B);
716-
@test C C2
717-
fill!(C, 9999.999); gemm_accurate!(C, A, Bt');
718-
@test C C2
719-
fill!(C, 9999.999); gemm_accurate!(C, At', Bt');
720-
@test C C2
713+
if RUN_SLOW_TESTS
714+
fill!(C, 9999.999); gemm_accurate!(C, A, B);
715+
@test C C2
716+
fill!(C, 9999.999); gemm_accurate!(C, At', B);
717+
@test C C2
718+
fill!(C, 9999.999); gemm_accurate!(C, A, Bt');
719+
@test C C2
720+
fill!(C, 9999.999); gemm_accurate!(C, At', Bt');
721+
@test C C2
722+
Ab = zeros(eltype(C), size(A)); Bb = zeros(eltype(C), size(B)); Cb = zero(C);
723+
threegemms!(Ab, Bb, Cb, A, B, C)
724+
@test Ab C * B'
725+
@test Bb A' * C
726+
@test Cb A * B
727+
end
721728
if iszero(size(A,1) % 8)
722729
Abit = A .> 0.5;
723730
fill!(C, 9999.999); AmulBavx1!(C, Abit, B);
@@ -728,11 +735,6 @@
728735
fill!(C, 9999.999); AmulBavx1!(C, A, Bbit);
729736
@test C A * Bbit
730737
end
731-
Ab = zeros(eltype(C), size(A)); Bb = zeros(eltype(C), size(B)); Cb = zero(C);
732-
threegemms!(Ab, Bb, Cb, A, B, C)
733-
@test Ab C * B'
734-
@test Bb A' * C
735-
@test Cb A * B
736738
end
737739
# exceeds_time_limit() && break
738740
@time @testset "_avx $T dynamic gemm" begin

test/runtests.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ using Test
22
using LoopVectorization
33
using LinearAlgebra
44

5+
import InteractiveUtils
6+
7+
InteractiveUtils.versioninfo(stdout; verbose = true)
8+
59
# const START_TIME = time()
610
# exceeds_time_limit() = (time() - START_TIME) > 35 * 60
711

@@ -30,12 +34,17 @@ Base.IndexStyle(::Type{<:FallbackArrayWrapper}) = IndexLinear()
3034

3135
@show LoopVectorization.REGISTER_COUNT
3236

37+
const RUN_SLOW_TESTS = LoopVectorization.REGISTER_COUNT 16 || !parse(Bool, get(ENV, "GITHUB_ACTIONS", "false"))
38+
@show RUN_SLOW_TESTS
39+
3340
@time @testset "LoopVectorization.jl" begin
3441

3542
@test isempty(detect_unbound_args(LoopVectorization))
3643

3744
@time include("printmethods.jl")
3845

46+
@time include("can_avx.jl")
47+
3948
@time include("fallback.jl")
4049

4150
@time include("utils.jl")
@@ -44,9 +53,7 @@ Base.IndexStyle(::Type{<:FallbackArrayWrapper}) = IndexLinear()
4453

4554
@time include("check_empty.jl")
4655

47-
if isnothing(get(ENV, "TRAVIS_BRANCH", nothing)) || LoopVectorization.REGISTER_COUNT 32 || VERSION v"1.4"
48-
@time include("offsetarrays.jl")
49-
end
56+
@time include("offsetarrays.jl")
5057

5158
@time include("tensors.jl")
5259

@@ -70,8 +77,5 @@ Base.IndexStyle(::Type{<:FallbackArrayWrapper}) = IndexLinear()
7077

7178
@time include("broadcast.jl")
7279

73-
# I test locally on master; times out on Travis.
74-
if isnothing(get(ENV, "TRAVIS_BRANCH", nothing)) || LoopVectorization.REGISTER_COUNT 32 || VERSION v"1.4"
75-
@time include("gemm.jl")
76-
end
80+
@time include("gemm.jl")
7781
end

0 commit comments

Comments
 (0)