Skip to content

Commit bda3516

Browse files
committed
Added check_args(::AbstractStridedPointer) method.
1 parent 964ae0d commit bda3516

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

src/condense_loopset.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ Additionally, define `pointer` and `stride` methods.
275275
check_args(parent(A)) # PermutedDimsArray, NamedDimsArray
276276
end
277277
end
278+
@inline check_args(A::VectorizationBase.AbstractStridedPointer{T}) where {T} = check_type(T)
278279
@inline check_args(A, Bs...) = check_args(A) && check_args(Bs...)
279280
"""
280281
check_type(::Type{T}) where {T}

test/miscellaneous.jl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,17 @@ using Test
1919
dot3(x, A, y) = dot(x, A, y)
2020
end
2121
function dot3avx(x, A, y)
22-
M, N = size(A)
2322
s = zero(promote_type(eltype(x), eltype(A), eltype(y)))
24-
@avx for m 1:M, n 1:N
23+
@avx for m axes(A,1), n axes(A,2)
2524
s += x[m] * A[m,n] * y[n]
2625
end
2726
s
2827
end
2928
function dot3v2avx(x, A, y)
30-
M, N = size(A)
3129
s = zero(promote_type(eltype(x), eltype(A), eltype(y)))
32-
@avx for n 1:N
30+
@avx for n axes(A,2)
3331
t = zero(s)
34-
for m 1:M
32+
for m axes(A,1)
3533
t += x[m] * A[m,n]
3634
end
3735
s += t * y[n]
@@ -788,8 +786,9 @@ end
788786
M, N = 47, 73;
789787
x = rand(T, M); A = rand(T, M, N); y = rand(T, N);
790788
d3 = dot3(x, A, y)
791-
@test dot3avx(x, A, y) d3
792-
@test dot3v2avx(x, A, y) d3
789+
@test dot3avx(LoopVectorization.stridedpointer(x), A, y) d3
790+
@test dot3v2avx(x, A, LoopVectorization.stridedpointer(y)) d3
791+
@test dot3avx24(x, A, y) d3
793792
@test dot3_avx(x, A, y) d3
794793

795794
A2 = similar(A);
@@ -930,15 +929,15 @@ end
930929
@test X1 X2
931930
@test Y1 Y2
932931

933-
# a_re, a_im = rand(T, 2, 2, 2), rand(T, 2, 2, 2);
934-
# b_re, b_im = rand(T, 2, 2), rand(T, 2, 2);
935-
# c_re_1 = ones(T, 2, 2); c_re_2 = ones(T, 2, 2);
936-
# multiple_unrolls_split_depchains!(c_re_1, a_re, b_re, a_im, b_im, true) # [1 1; 1 1]
937-
# multiple_unrolls_split_depchains_avx!(c_re_2, a_re, b_re, a_im, b_im, true) # [1 1; 1 1]
938-
# @test c_re_1 ≈ c_re_2
939-
# multiple_unrolls_split_depchains!(c_re_1, a_re, b_re, a_im, b_im) # [1 1; 1 1]
940-
# multiple_unrolls_split_depchains_avx!(c_re_2, a_re, b_re, a_im, b_im) # [1 1; 1 1]
941-
# @test c_re_1 ≈ c_re_2
932+
a_re, a_im = rand(T, 2, 2, 2), rand(T, 2, 2, 2);
933+
b_re, b_im = rand(T, 2, 2), rand(T, 2, 2);
934+
c_re_1 = ones(T, 2, 2); c_re_2 = ones(T, 2, 2);
935+
multiple_unrolls_split_depchains!(c_re_1, a_re, b_re, a_im, b_im, true) # [1 1; 1 1]
936+
multiple_unrolls_split_depchains_avx!(c_re_2, a_re, b_re, a_im, b_im, true) # [1 1; 1 1]
937+
@test c_re_1 c_re_2
938+
multiple_unrolls_split_depchains!(c_re_1, a_re, b_re, a_im, b_im) # [1 1; 1 1]
939+
multiple_unrolls_split_depchains_avx!(c_re_2, a_re, b_re, a_im, b_im) # [1 1; 1 1]
940+
@test c_re_1 c_re_2
942941

943942
@test loopinductvardivision(X1) loopinductvardivisionavx(X2)
944943

0 commit comments

Comments
 (0)