Skip to content

Commit 7e5bccd

Browse files
committed
Fix for loops without arrays, and improve performance of vmap_multithreaded!. It's a bit of a silly/experimental definition, but...
1 parent 86dc4ac commit 7e5bccd

File tree

5 files changed

+133
-28
lines changed

5 files changed

+133
-28
lines changed

src/lowering.jl

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,10 +699,52 @@ function gc_preserve(ls::LoopSet, q::Expr)
699699
# Expr(:block, gcp)
700700
end
701701

702+
function typeof_outer_reduction_init(ls::LoopSet, op::Operation)
703+
opid = identifier(op)
704+
for (id, sym) ls.preamble_symsym
705+
opid == id && return Expr(:call, :typeof, sym)
706+
end
707+
for (id,intval) ls.preamble_symint
708+
opid == id && return :Int
709+
end
710+
for (id,floatval) ls.preamble_symfloat
711+
opid == id && return :Float64
712+
end
713+
for (id,typ) ls.preamble_zeros
714+
instruction(ops[id]) === LOOPCONSTANT || continue
715+
opid == id || continue
716+
if typ == IntOrFloat
717+
return :Float64
718+
elseif typ == HardInt
719+
return :Int
720+
else#if typ == HardFloat
721+
return :Float64
722+
end
723+
end
724+
throw("Could not find initializing constant.")
725+
end
726+
function typeof_outer_reduction(ls::LoopSet, op::Operation)
727+
for opp operations(ls)
728+
opp === op && continue
729+
name(op) === name(opp) && return typeof_outer_reduction_init(ls, opp)
730+
end
731+
throw("Could not find initialization op.")
732+
end
702733

703-
function determine_eltype(ls::LoopSet)
734+
function determine_eltype(ls::LoopSet)::Union{Symbol,Expr}
704735
if length(ls.includedactualarrays) == 0
705-
return Expr(:call, lv(:typeof), 0)
736+
if length(ls.outer_reductions) == 0
737+
return Expr(:call, lv(:typeof), 0)
738+
elseif length(ls.outer_reductions) == 1
739+
op = ls.operations[only(ls.outer_reductions)]
740+
return typeof_outer_reduction(ls, op)
741+
else
742+
pt = Expr(:call, lv(:promote_type))
743+
for j ls.outer_reductions
744+
push!(pt.args, typeof_outer_reduction(ls, ls.operations[j]))
745+
end
746+
return pt
747+
end
706748
elseif length(ls.includedactualarrays) == 1
707749
return Expr(:call, lv(:eltype), first(ls.includedactualarrays))
708750
end
@@ -783,6 +825,7 @@ end
783825
function define_eltype_vec_width!(q::Expr, ls::LoopSet, vectorized)
784826
push!(q.args, Expr(:(=), ELTYPESYMBOL, determine_eltype(ls)))
785827
push!(q.args, Expr(:(=), VECTORWIDTHSYMBOL, determine_width(ls, vectorized)))
828+
nothing
786829
end
787830
function setup_preamble!(ls::LoopSet, us::UnrollSpecification, Ureduct::Int)
788831
@unpack u₁loopnum, u₂loopnum, vectorizedloopnum, u₁, u₂ = us
@@ -791,7 +834,7 @@ function setup_preamble!(ls::LoopSet, us::UnrollSpecification, Ureduct::Int)
791834
u₂loopsym = order[u₂loopnum]
792835
vectorized = order[vectorizedloopnum]
793836
set_vector_width!(ls, vectorized)
794-
iszero(length(ls.includedactualarrays)) || define_eltype_vec_width!(ls.preamble, ls, vectorized)
837+
iszero(length(ls.includedactualarrays) + length(ls.outer_reductions)) || define_eltype_vec_width!(ls.preamble, ls, vectorized)
795838
lower_licm_constants!(ls)
796839
isone(num_loops(ls)) || pushpreamble!(ls, definemask(getloop(ls, vectorized)))#, u₁ > 1 && u₁loopnum == vectorizedloopnum))
797840
initialize_outer_reductions!(ls, 0, Ureduct, vectorized)

src/map.jl

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,12 @@ function vmap_multithreaded!(
9595
V = VectorizationBase.pick_vector_width_val(T)
9696
Wsh = Wshift + 2
9797
Niter = N >>> Wsh
98-
Base.Threads.@threads :static for j 0:Niter-1
99-
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((j << Wsh,))
100-
vstorent!(ptry, f(vload.(ptrargs, index)...), index)
98+
let Wsh = Wsh, ptry = ptry, ptrargs = ptrargs
99+
Base.Threads.@threads :static for j 0:Niter-1
100+
W = VectorizationBase.pick_vector_width(eltype(ptry))
101+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((j << Wsh,))
102+
vstorent!(ptry, f(vload.(ptrargs, index)...), index)
103+
end
101104
end
102105
ii = Niter << Wsh
103106
while ii < N - (W - 1) # stops at 16 when
@@ -110,34 +113,59 @@ function vmap_multithreaded!(
110113
end
111114
y
112115
end
113-
function vmap_multithreaded!(
116+
struct VmapClosure{F,D,N,A<:Tuple{Vararg{Any,N}}}
117+
f::F
118+
dest::D
119+
args::A
120+
end
121+
(m::VmapClosure)() = vmap_singlethread!(m.f, m.dest, Val{false}(), m.args...)
122+
@generated function vmap_multithreaded!(
114123
f::F,
115124
y::AbstractArray{T},
116125
::Val{false},
117126
args::Vararg{AbstractArray,A}
118127
) where {F,T,A}
119-
N = length(y)
120-
ptry = VectorizationBase.zstridedpointer(y)
121-
ptrargs = VectorizationBase.zstridedpointer.(args)
122-
N > 0 || return y
123-
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
124-
V = VectorizationBase.pick_vector_width_val(T)
125-
Wsh = Wshift + 2
126-
Niter = N >>> Wsh
127-
Base.Threads.@threads :static for j 0:Niter-1
128-
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((j << Wsh,))
129-
vnoaliasstore!(ptry, f(vload.(ptrargs, index)...), index)
130-
end
131-
ii = Niter << Wsh
132-
while ii < N - (W - 1) # stops at 16 when
133-
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),))...), (MM{W}(ii),))
134-
ii = vadd_fast(ii, W)
135-
end
136-
if ii < N
137-
m = mask(T, N & (W - 1))
138-
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),), m)...), (MM{W}(ii),), m)
128+
quote
129+
N = length(y)
130+
nt = min(Threads.nthreads(), $(Sys.CPU_THREADS))
131+
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
132+
(((W * nt < N) & (nt > 1)) && iszero(ccall(:jl_in_threaded_region, Cint, ()))) || return vmap_singlethread!(f, y, Val{false}(), args...)
133+
Nd, Nr = divrem(N >>> Wshift, nt)
134+
Ndb = Nd << Wshift
135+
Ndbr = Ndb + W
136+
Nlast = N - Ndbr * Nr - Ndb * (nt - 1 - Nr)
137+
yfi = firstindex(y);
138+
Base.Cartesian.@nexprs $A a -> begin
139+
args_a = args[a]
140+
argsfi_a = firstindex(args_a);
141+
end
142+
lb = 0
143+
# tasks = Vector{Task}(undef, nt)
144+
# tasks = Base.Cartesian.@ntuple $(Sys.CPU_THREADS) t -> Ref{Task}()
145+
Base.Cartesian.@nexprs $(Sys.CPU_THREADS) j -> begin
146+
# for j ∈ Base.OneTo(nt)
147+
Nlen = j == nt ? Nlast : (j > Nr ? Ndb : Ndbr)
148+
ub = lb + Nlen
149+
yv = view(y, yfi+lb:yfi+ub-1)
150+
argsview = Base.Cartesian.@ntuple $A a -> view(args_a, argsfi_a+lb:argsfi_a+ub-1)
151+
t_j = Task(VmapClosure(f, yv, argsview))
152+
# tasks[j] = t
153+
t_j.sticky = true
154+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t_j, (j == nt ? 0 : j) % Cint)
155+
schedule(t_j)
156+
j == nt && @goto WAIT
157+
lb = ub
158+
end
159+
@label WAIT
160+
Base.Cartesian.@nexprs $(Sys.CPU_THREADS) j -> begin
161+
wait(t_j)
162+
j == nt && return y
163+
end
164+
# for j ∈ Base.OneTo(nt)
165+
# wait(tasks[j])
166+
# end
167+
y
139168
end
140-
y
141169
end
142170

143171
Base.@pure _all_dense(::ArrayInterface.DenseDims{D}) where {D} = all(D)

src/reconstruct_loopset.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ end
393393
# elbytes(::VectorizationBase.AbstractPointer{T}) where {T} = sizeof(T)::Int
394394
typeeltype(::Type{P}) where {T,P<:VectorizationBase.AbstractStridedPointer{T}} = T
395395
typeeltype(::Type{VectorizationBase.FastRange{T,F,S,O}}) where {T,F,S,O} = T
396+
typeeltype(::Type{T}) where {T<:Real} = T
396397
# typeeltype(::Any) = Int8
397398

398399
function add_array_symbols!(ls::LoopSet, arraysymbolinds::Vector{Symbol}, offset::Int)

test/loopinductvars.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
# testset for using them in loops
3+
@testset "Loop Induction Variables" begin
4+
f(x) = cos(x) * log(x)
5+
function avxmax(v)
6+
max_x = -Inf
7+
@avx for i eachindex(v)
8+
x = f(i)
9+
max_x = max(max_x, x)
10+
end
11+
max_x
12+
end
13+
function avxextrema(v)
14+
max_x = -Inf
15+
min_x = Inf
16+
@avx for i eachindex(v)
17+
x = f(i)
18+
max_x = max(max_x, x)
19+
min_x = min(min_x, x)
20+
end
21+
min_x, max_x
22+
end
23+
24+
v = 1:19
25+
minref, maxref = extrema(f, v)
26+
@test maxref avxmax(v)
27+
minavx, maxavx = avxextrema(v);
28+
@test minref minavx
29+
@test maxref maxavx
30+
end
31+

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ const START_TIME = time()
2727

2828
@time include("check_empty.jl")
2929

30+
@time include("loopinductvars.jl")
31+
3032
@time include("zygote.jl")
3133

3234
@time include("offsetarrays.jl")

0 commit comments

Comments
 (0)