Skip to content

Commit d9517ee

Browse files
committed
Fix but in LowDimArray{ntuple(_ -> false, Val(N))}(data), and add vmapreduce to resolve #110.
1 parent d493255 commit d9517ee

File tree

5 files changed

+118
-1
lines changed

5 files changed

+118
-1
lines changed

src/broadcast.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
115115
data::A
116116
end
117117
@inline Base.pointer(A::LowDimArray) = pointer(A.data)
118+
Base.@propagate_inbounds Base.getindex(A::LowDimArray, i...) = getindex(A.data, i...)
118119
Base.size(A::LowDimArray) = Base.size(A.data)
119120
@generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}
120121
s = Expr(:tuple, [Expr(:ref, :strideA, n) for n 1+D[1]:N if D[n]]...)
@@ -125,11 +126,19 @@ end
125126
function LowDimArray{D}(data::A) where {D,T,N,A <: AbstractArray{T,N}}
126127
LowDimArray{D,T,N,A}(data)
127128
end
129+
function extract_all_1_array!(ls::LoopSet, bcname::Symbol, N::Int, elementbytes::Int)
130+
refextract = gensym(bcname)
131+
pushpreamble!(ls, Expr(:(=), refextract, Expr(:ref, bcname, [1 for n 1:N]...)))
132+
return add_constant!(ls, refextract, elementbytes) # or replace elementbytes with sizeof(T) ? u
133+
end
128134
function add_broadcast!(
129135
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
130136
@nospecialize(LDA::Type{<:LowDimArray}), elementbytes::Int
131137
)
132138
D,T,N::Int,_ = LDA.parameters
139+
if !any(D)
140+
return extract_all_1_array!(ls, bcname, N, elementbytes)
141+
end
133142
fulldims = Symbol[loopsyms[n] for n 1:N if D[n]::Bool]
134143
ref = ArrayReference(bcname, fulldims)
135144
add_simple_load!(ls, destname, ref, elementbytes, true, false )::Operation
@@ -139,12 +148,14 @@ function add_broadcast_adjoint_array!(
139148
) where {T,N,A<:AbstractArray{T,N}}
140149
parent = gensym(:parent)
141150
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
151+
# isone(length(loopsyms)) && return extract_all_1_array!(ls, bcname, N, elementbytes)
142152
ref = ArrayReference(parent, Symbol[loopsyms[N + 1 - n] for n 1:N])
143153
add_simple_load!( ls, destname, ref, elementbytes, true, true )::Operation
144154
end
145155
function add_broadcast_adjoint_array!(
146156
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{<:AbstractVector}, elementbytes::Int
147157
)
158+
# isone(length(loopsyms)) && return extract_all_1_array!(ls, bcname, N, elementbytes)
148159
ref = ArrayReference(bcname, Symbol[loopsyms[2]])
149160
add_simple_load!( ls, destname, ref, elementbytes, true, true )
150161
end

src/mapreduce.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
@inline vreduce(::typeof(+), v::VectorizationBase.AbstractSIMDVector) = vsum(v)
3+
@inline vreduce(::typeof(*), v::VectorizationBase.AbstractSIMDVector) = vprod(v)
4+
@inline vreduce(::typeof(max), v::VectorizationBase.AbstractSIMDVector) = vmaximum(v)
5+
@inline vreduce(::typeof(min), v::VectorizationBase.AbstractSIMDVector) = vminimum(v)
6+
@inline vreduce(op, v::VectorizationBase.AbstractSIMDVector) = _vreduce(op, v)
7+
@inline _vreduce(op, v::VectorizationBase.AbstractSIMDVector) = _reduce(op, SVec(v))
8+
@inline function _vreduce(op, v::SVec)
9+
isone(length(v)) && return v[1]
10+
a = op(v[1], v[2])
11+
for i 3:length(v)
12+
a = op(a, v[i])
13+
end
14+
a
15+
end
16+
17+
function mapreduce_simple(f::F, op::OP, args::Vararg{DenseArray{T},A}) where {F,OP,T<:NativeTypes,A}
18+
ptrargs = ntuple(a -> pointer(args[a]), Val(A))
19+
N = length(first(args))
20+
iszero(N) && throw("Length of vector is 0!")
21+
a_0 = f(vload.(ptrargs)...); i = 1
22+
while i < N
23+
a_0 = op(a_0, f(vload.(ptrargs, i)...)); i += 1
24+
end
25+
a_0
26+
end
27+
28+
29+
"""
30+
vmapreduce(f, op, A::DenseArray...)
31+
32+
Vectorized version of `mapreduce`. Applies `f` to each element of the arrays `A`, and reduces the result with `op`.
33+
"""
34+
function vmapreduce(f::F, op::OP, args::Vararg{DenseArray{T},A}) where {F,OP,T<:NativeTypes,A}
35+
N = length(first(args))
36+
A > 1 && @assert all(isequal(length.(args)...))
37+
W = VectorizationBase.pick_vector_width(T)
38+
V = VectorizationBase.pick_vector_width_val(T)
39+
N < W && return mapreduce_simple(f, op, args...)
40+
ptrargs = pointer.(args)
41+
42+
a_0 = f(vload.(V, ptrargs)...); i = W
43+
if N 4W
44+
a_1 = f(vload.(V, ptrargs, i)...); i += W
45+
a_2 = f(vload.(V, ptrargs, i)...); i += W
46+
a_3 = f(vload.(V, ptrargs, i)...); i += W
47+
while i < N - ((W << 2) - 1)
48+
a_0 = op(a_0, f(vload.(V, ptrargs, i)...)); i += W
49+
a_1 = op(a_1, f(vload.(V, ptrargs, i)...)); i += W
50+
a_2 = op(a_2, f(vload.(V, ptrargs, i)...)); i += W
51+
a_3 = op(a_3, f(vload.(V, ptrargs, i)...)); i += W
52+
end
53+
a_0 = op(a_0, a_1)
54+
a_2 = op(a_2, a_3)
55+
a_0 = op(a_0, a_2)
56+
end
57+
while i < N - (W - 1)
58+
a_0 = op(a_0, f(vload.(V, ptrargs, i)...)); i += W
59+
end
60+
if i < N
61+
m = mask(T, N & (W - 1))
62+
a_0 = vifelse(m, op(a_0, f(vload.(V, ptrargs, i)...)), a_0)
63+
end
64+
vreduce(op, a_0)
65+
end
66+
67+
@inline vmapreduce(f, op, args...) = mapreduce(f, op, args...)
68+
69+
70+
"""
71+
vreduce(op, destination, A::DenseArray...)
72+
73+
Vectorized version of `reduce`. Reduces the array `A` using the operator `op`.
74+
"""
75+
@inline vreduce(op, arg) = vmapreduce(identity, op, arg)
76+

test/broadcast.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
fill!(c2, 99999);
2929
@avx @. c2 = a + bl;
3030
@test c1 c2
31+
32+
xs = rand(T, M);
33+
max_ = maximum(xs, dims=1)
34+
@test (@avx exp.(xs .- LowDimArray{(false,)}(max_))) exp.(xs .- LowDimArray{(false,)}(max_))
35+
3136

3237
a = rand(R, M); B = rand(R, M, N); c = rand(R, N); c′ = c';
3338
d1 = @. a + B * c′;

test/gemm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@
634634
@test C 2C2
635635
AmuladdBavx!(C, A, B, -1)
636636
@test C C2
637-
AmuladdBavx!(C, At', B, -2)
637+
AmuladdBavx!(C, At', Bt', -2)
638638
@test C -C2
639639
AmuladdBavx!(C, At', B, 3, 2)
640640
@test C C2

test/mapreduce.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
@testset "mapreduce" begin
3+
4+
for T (Int32, Int64, Float32, Float64)
5+
if T <: Integer
6+
R = T(1):T(100)
7+
x7 = rand(R, 7); y7 = rand(R, 7);
8+
x = rand(R, 127); y = rand(R, 127);
9+
else
10+
x7 = rand(T, 7); y7 = rand(T, 7);
11+
x = rand(T, 127); y = rand(T, 127);
12+
@test vmapreduce(hypot, +, x, y) mapreduce(hypot, +, x, y)
13+
@test vmapreduce(^, (a,b) -> a + b, x7, y7) mapreduce(^, (a,b) -> a + b, x7, y7)
14+
end
15+
@test vreduce(+, x7) sum(x7)
16+
@test vreduce(+, x) sum(x)
17+
@test_throws AssertionError vmapreduce(hypot, +, x7, x)
18+
@test vmapreduce(a -> 2a, *, x) mapreduce(a -> 2a, *, x)
19+
@test vmapreduce(sin, +, x7) mapreduce(sin, +, x7)
20+
@test vmapreduce(log, +, x) mapreduce(log, +, x)
21+
@test vmapreduce(abs2, +, x) mapreduce(abs2, +, x)
22+
end
23+
24+
end
25+

0 commit comments

Comments
 (0)