Skip to content

Commit 5e54996

Browse files
committed
Some progress on threading
1 parent 7473acc commit 5e54996

File tree

5 files changed

+282
-33
lines changed

5 files changed

+282
-33
lines changed

src/codegen/lower_load.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ function lower_load!(
285285
if (suffix != -1) && ls.loadelimination[]
286286
if (u₁ > 1) & (u₂max > 1)
287287
istr, ispl = isoptranslation(ls, op, UnrollSymbols(u₁loopsym, u₂loopsym, vloopsym))
288-
if istr ≠ 0
288+
if istr ≠ 0x00
289289
return lower_load_for_optranslation!(q, op, ispl, ls, td, mask, istr)
290290
end
291291
end

src/codegen/lower_threads.jl

Lines changed: 248 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,40 @@ function (::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V})(p::Ptr{UInt}) where {UNROLL,OPS,A
1010
ThreadingUtilities.store!(p, ret, 7)
1111
nothing
1212
end
13+
@generated function Base.pointer(::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
14+
f = AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()
15+
precompile(f, (Ptr{UInt},))
16+
quote
17+
$(Expr(:meta,:inline))
18+
@cfunction($f, Cvoid, (Ptr{UInt},))
19+
end
20+
end
21+
22+
function launch!(p::Ptr{UInt}, fptr::Ptr{Cvoid}, args::Tuple{LB,V}) where {LB,V}
23+
offset = ThreadingUtilities.store!(p, fptr, 0)
24+
offset = ThreadingUtilities.store!(p, args, offset)
25+
nothing
26+
end
27+
function launch(
28+
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, lb::LB, vargs::V, tid
29+
) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
30+
p = ThreadingUtilities.taskpointer(tid)
31+
f = AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()
32+
fptr = pointer(f)
33+
while true
34+
if ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.SPIN, ThreadingUtilities.STUP)
35+
launch!(p, fptr, (lb,vargs))
36+
@assert ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.STUP, ThreadingUtilities.TASK)
37+
return
38+
elseif ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.WAIT, ThreadingUtilities.STUP)
39+
launch!(p, fptr, (lb,vargs))
40+
@assert ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.STUP, ThreadingUtilities.LOCK)
41+
ThreadingUtilities.wake_thread!(tid % UInt)
42+
return
43+
end
44+
ThreadingUtilities.pause()
45+
end
46+
end
1347

1448
# function approx_cbrt(x)
1549
# s = significand(x)
@@ -18,21 +52,226 @@ end
1852
# # 40 + 0.00020833333333333335*(x-64000) -2.1701388888888896e-9*(x-64000)^2*0.5 + 5.6514033564814844e-14 * (x-64000)^3/6
1953
# end
2054

21-
function choose_threads(::StaticInt{C}, x) where {C}
55+
function choose_num_threads(::StaticInt{C}, x) where {C}
2256
nt = ifelse(gt(num_threads(), num_cores()), num_cores(), num_threads())
2357
fx = Base.uitofp(Float64, x)
24-
min(Base.fptosi(Int, Base.ceil_llvm(5.0852672001495816e-11*C*Base.sqrt_llvm(fx))), nt)
58+
min(Base.fptoui(UInt, Base.ceil_llvm(5.0852672001495816e-11*C*Base.sqrt_llvm(fx))), UInt(nt))
59+
end
60+
function push_loop_length_expr!(q::Expr, ls::LoopSet)
61+
l = 1
62+
ndynamic = 0
63+
mulexpr = length(ls.loops) == 1 ? q : Expr(:call, lv(:vmul_fast))
64+
for loop ∈ ls.loops
65+
if isstaticloop(loop)
66+
l *= length(loop)
67+
else
68+
ndynamic += 1
69+
if ndynamic < 3
70+
push!(mulexpr.args, loop.lensym)
71+
else
72+
mulexpr = Expr(:call, lv(:vmul_fast), mulexpr, loop.lensym)
73+
end
74+
end
75+
end
76+
if length(ls.loops) == 1
77+
ndynamic == 0 && push!(q.args, l)
78+
elseif l == 1
79+
push!(q.args, mulexpr)
80+
elseif ndynamic == 0
81+
push!(q.args, l)
82+
elseif ndynamic == 1
83+
push!(mulexpr.args, l)
84+
push!(q.args, mulexpr)
85+
else
86+
push!(q.args, Expr(:call, :vmul_fast, mulexpr, l))
87+
end
88+
nothing
89+
end
90+
function divrem_fast(numerator, denominator)
91+
d = Base.udiv_int(numerator, denominator)
92+
r = numerator - denominator*d
93+
d, r
94+
end
95+
96+
function outer_reduct_combine_expressions(ls::LoopSet, retv)
97+
q = Expr(:block, :(var"#load#thread#ret#" = ThreadingUtilities.store!(var"#thread#ptr#", typeof($retv), 7)))
98+
for (i,or) ∈ enumerate(ls.outer_reductions)
99+
op = ls.operations[or]
100+
var = name(op)
101+
mvar = mangledvar(op)
102+
instr = instruction(op)
103+
out = Symbol(mvar, "##onevec##")
104+
instrcall = callexp(instr)
105+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), out))
106+
if length(ls.outer_reductions) > 1
107+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Expr(:call, GlobalRef(Core, :getfield), Symbol("#load#thread#ret#"), i, false)))
108+
else
109+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Symbol("#load#thread#ret#")))
110+
end
111+
push!(q.args, Expr(:(=), out, Expr(:call, :data, instrcall)))
112+
end
113+
q
114+
end
115+
116+
function thread_loop_summary!(ls, threadedloop::Loop, u₁loop::Loop, u₂loop::Loop, vloop::Loop, issecondthreadloop::Bool)
117+
threadloopnumtag = Int(issecondthreadloop)
118+
lensym = Symbol("#len#thread#$threadloopnumtag#")
119+
define_len = if isstaticloop(threadedloop)
120+
:($lensym = $(length(threadedloop)))
121+
else
122+
:($lensym = $((threadedloop.lensym)))
123+
end
124+
unroll_factor = 1
125+
if threadedloop === vloop
126+
unroll_factor *= W
127+
end
128+
if threadedloop === u₁loop
129+
unroll_factor *= u₁
130+
elseif threadedloop === u₂loop
131+
unroll_factor *= u₂
132+
end
133+
num_unroll_sym = Symbol("#num#unrolls#thread#$threadloopnumtag#")
134+
define_num_unrolls = if unroll_factor == 1
135+
:($num_unroll_sym = $lensym)
136+
else
137+
:($num_unroll_sym = Base.udiv_int($lensym, $(UInt(unroll_factor))))
138+
end
139+
iterstart_sym = Symbol("#iter#start#$threadloopnumtag#")
140+
iterstop_sym = Symbol("#iter#stop#$threadloopnumtag#")
141+
blksz_sym = Symbol("#nblock#size#thread#$threadloopnumtag#")
142+
loopstart = if isknown(first(threadedloop))
143+
:($iterstart_sym = $(gethint(first(threadedloop))))
144+
else
145+
:($iterstart_sym = $(getsym(first(threadedloop))))
146+
end
147+
if isknown(step(threadedloop))
148+
mf = gethint(threadedloop) * unroll_factor
149+
if isone(mf)
150+
iterstop = :($iterstop_sym = $iterstart_sym + $blksz_sym)
151+
looprange = :(CloseOpen($iterstart_sym, $iterstop_sym))
152+
lastrange = if isknown(last(threadedloop))
153+
:(CloseOpen($iterstart_sym,$(gethint(threadedloop)+1)))
154+
else # we want all the intervals to have the same type.
155+
:(CloseOpen($iterstart_sym,$(getsym(threadedloop))+1))
156+
end
157+
else
158+
iterstop = :($iterstop_sym = $iterstart_sym + $blksz_sym * $mf)
159+
looprange = :($iterstart_sym:StaticInt{$mf}():$iterstop_sym-1)
160+
lastrange = if isknown(last(threadedloop))
161+
:($iterstart_sym:StaticInt{$mf}():$(gethint(threadedloop)))
162+
else
163+
:($iterstart_sym:StaticInt{$mf}():$(getsym(threadedloop)))
164+
end
165+
end
166+
else
167+
stepthread_sym = Symbol("#step#thread#$threadloopnumtag#")
168+
pushpreamble!(ls, :($stepthread_sym = $unroll_factor * $(getsym(step(threadedloop)))))
169+
iterstop = :($iterstop_sym = $iterstart_sym + $blksz_sym * $stepthread_sym)
170+
looprange = :($iterstart_sym:$stepthread_sym:$iterstop_sym-1)
171+
lastrange = if isknown(last(threadedloop))
172+
:($iterstart_sym:$stepthread_sym:$(gethint(threadedloop)))
173+
else
174+
:($iterstart_sym:$stepthread_sym:$(getsym(threadedloop)))
175+
end
176+
end
177+
define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange
25178
end
26179

27-
function thread_single_loop_expr(ls::LoopSet, UNROLL, id)
180+
function thread_single_loop_expr(ls::LoopSet, ua::UnrollArgs, valid_thread_loop, c, UNROLL, OPS, ARF, AM, LPSYM)
181+
choose_nthread = :(choose_num_threads(StaticInt{$c}()))
182+
push_loop_length_expr!(choose_nthread, ls)
183+
threadedid = findfirst(valid_thread_loop)::Int
184+
@unpack u₁loop, u₂loop, vloop, u₁, u₂ = ua
185+
W = ls.vector_width[]
186+
threadedloop = getloop(ls, threadedid)
187+
define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange = thread_loop_summary!(ls, threadedloop, u₁loop, u₂loop, vloop, 0)
188+
loopboundexpr = Expr(:tuple)
189+
lastboundexpr = Expr(:tuple)
190+
for (i,loop) ∈ enumerate(threadedloop)
191+
if loop === threadedloop
192+
push!(loopboundexpr.args, looprange)
193+
push!(lastboundexpr.args, lastrange)
194+
else
195+
loop_boundary!(loopboundexpr, loop)
196+
loop_boundary!(lastboundexpr, loop)
197+
end
198+
end
199+
_avx_call_ = :(_avx_!(Val{$UNROLL}(), Val{$OPS}(), Val{$ARF}(), Val{$AM}(), Val{$LPSYM}(), $lastboundexpr, var"#vargs#"))
200+
update_return_values = if length(ls.outer_reductions) > 0
201+
retv = loopset_return_value(ls, Val(false))
202+
_avx_call_ = Expr(:(=), retv, _avx_call_)
203+
outer_reduct_combine_expressions(ls, retv)
204+
else
205+
nothing
206+
end
207+
q = quote
208+
var"#nthreads#" = $choose_nthread # UInt
209+
$define_len % UInt
210+
$define_num_unrolls
211+
var"#nthreads#" = Base.min(var"#nthreads#", $num_unrolls)
212+
var"#nrequest#" = (var"#nthreads#" % UInt32) - 0x00000001
213+
var"#nrequest#" == 0x00000000 && return LoopVectorization._avx_!(Val{$UNROLL}(), Val{$OPS}, Val{$ARF}(), Val{$AM}(), Val{$LPSYM}(), var"#lv#tuple#args#")
214+
var"#threads#", var"#torelease#" = LoopVectorization._request_threads(Threads.threadid(), var"#nrequest#")
28215

216+
var"#base#block#size#thread#0#", var"#nrem#thread#" = LoopVectorization.divrem_fast(num_unrolls, var"#nthreads#")
217+
$loopstart
218+
219+
var"#thread#launch#count#" = 0x00000000
220+
var"#thread#id#" = 0x00000000
221+
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
222+
var"#threads#remain#" = true
223+
while var"#threads#remain#"
224+
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
225+
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
226+
var"#thread#launch#count#" += 0x00000001
227+
var"#nblock#size#thread#0#" = Core.ifelse(
228+
var"#thread#launch#count#" < (var"#nrem#thread#" % Base.typeof(var"#threadid#")),
229+
var"#base#block#size#thread#0#" + Base.one(var"#base#block#size#thread#0#"),
230+
var"#base#block#size#thread#0#"
231+
)
232+
var"#trailzing#zeros#" += 0x00000001
233+
$iterstop
234+
var"#thread#id#" += var"#trailzing#zeros#"
235+
236+
LoopVectorization.launch(
237+
Val{$UNROLL}(), Val{$OPS}(), Val{$ARF}(), Val{$AM}(), Val{$LPSYM}(),
238+
$loopboundexpr, var"#vargs#", var"#thread#id#"
239+
)
240+
241+
var"#thread#mask#" >>>= var"#trailzing#zeros#"
242+
243+
var"#iter#start#0#" = var"#iter#stop#0#"
244+
var"#threads#remain#" = var"#thread#launch#count#" ≠ var"$nrequest#"
245+
end
246+
$_avx_call_
247+
var"#thread#id#" = 0x00000000
248+
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
249+
var"#threads#remain#" = true
250+
while var"#threads#remain#"
251+
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
252+
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
253+
var"#trailzing#zeros#" += 0x00000001
254+
var"#thread#mask#" >>>= var"#trailzing#zeros#"
255+
var"#thread#id#" += var"#trailzing#zeros#"
256+
var"#thread#ptr#" = ThreadingUtilities.taskpointer(var"#thread#id#")
257+
ThreadingUtilities.__wait(var"#thread#ptr#")
258+
$update_return_values
259+
var"#threads#remain#" = var"#thread#mask#" ≠ 0x00000000
260+
end
261+
CheapThreads.free_threads!(var"#torelease#")
262+
end
263+
length(ls.outer_reductions) > 0 ? push!(q.args, retv) : push!(q.args, nothing)
264+
q
29265
end
30266
function thread_multiple_loop_expr(ls::LoopSet, UNROLL, valid_thread_loop)
31267

32268
end
33269

34-
function avx_threads_expr(ls::LoopSet, UNROLL)
270+
function valid_thread_loops(ls::LoopSet)
35271
order, u₁loop, u₂loop, vectorized, u₁, u₂, c, shouldinline = choose_order_cost(ls)
272+
# NOTE: `names` are being placed in the opposite order here versus normal lowering!
273+
copyto!(names(ls), order); init_loop_map!(ls)
274+
ua = UnrollArgs(getloop(ls, u₁loop), getloop(ls, u₂loop), getloop(ls, vloop), u₁, u₂, u₂)
36275
valid_thread_loop = fill(true, length(order))
37276
for op ∈ operations(ls)
38277
if isstore(op) && (length(reduceddependencies(op)) > 0)
@@ -45,6 +284,10 @@ function avx_threads_expr(ls::LoopSet, UNROLL)
45284
end
46285
end
47286
end
287+
valid_thread_loop, ua, c
288+
end
289+
function avx_threads_expr(ls::LoopSet, UNROLL)
290+
valid_thread_loop, us, c = valid_thread_loops(ls)
48291
num_candiates = sum(valid_thread_loop)
49292
# num_to_thread = min(num_candiates, 2)
50293
# candidate_ids =
@@ -54,8 +297,7 @@ function avx_threads_expr(ls::LoopSet, UNROLL)
54297
thread_single_loop_expr(ls, UNROLL, findfirst(isone, valid_thread_loop)::Int)
55298
else
56299
thread_multiple_loop_expr(ls, UNROLL, vald_thread_loop)
57-
end
58-
300+
end
59301
end
60302

61303

src/modeling/determinestrategy.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -665,28 +665,33 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
665665
end
666666
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
667667
@unpack u₁loopsym, u₂loopsym, vloopsym = unrollsyms
668-
(vloopsym == u₁loopsym || vloopsym == u₂loopsym) && return 0, false
669-
(isu₁unrolled(op) && isu₂unrolled(op)) || return 0, false
668+
(vloopsym == u₁loopsym || vloopsym == u₂loopsym) && return 0, 0x00
669+
(isu₁unrolled(op) && isu₂unrolled(op)) || return 0, 0x00
670670
u₁step = step(getloop(ls, u₁loopsym))
671671
u₂step = step(getloop(ls, u₂loopsym))
672-
(isknown(u₁step) & isknown(u₂step)) || return 0, false
673-
abs(gethint(u₁step)) == abs(gethint(u₂step)) || return 0, false
672+
(isknown(u₁step) & isknown(u₂step)) || return 0, 0x00
673+
abs(gethint(u₁step)) == abs(gethint(u₂step)) || return 0, 0x00
674674

675675
istranslation = 0
676676
inds = getindices(op); li = op.ref.loopedindex
677677
for i ∈ eachindex(li)
678678
if !li[i]
679679
opp = findparent(ls, inds[i + (first(inds) === DISCONTIGUOUS)])
680680
if isu₁unrolled(opp) && isu₂unrolled(opp)
681-
isadd = instruction(opp).instr === :(+)
682-
issub = instruction(opp).instr === :(-)
683-
if isadd | issub
684-
return i, isadd
681+
if instruction(opp).instr === :(+)
682+
return i, 0x03 # 00000011 - both positive
683+
elseif instruction(opp).instr === :(-)
684+
oppp1 = parents(opp)[1]
685+
if isu₁unrolled(oppp1)
686+
return i, 0x01 # 00000001 - u₁ positive, u₂ negative
687+
else#isu₂unrolled(oppp1)
688+
return i, 0x02 # 00000010 - u₂ positive, u₁ negative
689+
end
685690
end
686691
end
687692
end
688693
end
689-
0, false
694+
0, 0x00
690695
end
691696
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
692697
mno = typemin(Int)

src/parse/add_compute.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,16 @@ function search_tree(opv::Vector{Operation}, var::Symbol) # relies on cycles bei
8787
false
8888
end
8989

90-
function update_for_ref_reduction!()
91-
if varname(mpref) === var
92-
id = findfirst(r -> r == mpref.mref, ls.refs_aliasing_syms)
93-
mpref.varname = var = id === nothing ? var : ls.syms_aliasing_refs[id]
94-
reduction_ind = ind
95-
mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
96-
else
97-
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
98-
end
99-
end
90+
# function update_for_ref_reduction!()
91+
# if varname(mpref) === var
92+
# id = findfirst(r -> r == mpref.mref, ls.refs_aliasing_syms)
93+
# mpref.varname = var = id === nothing ? var : ls.syms_aliasing_refs[id]
94+
# reduction_ind = ind
95+
# mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
96+
# else
97+
# pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
98+
# end
99+
# end
100100
search_tree_for_ref(ls::LoopSet, opv::Vector{Operation}, ::Nothing, var::Symbol) = var, false
101101
function search_tree_for_ref(ls::LoopSet, opv::Vector{Operation}, mpref::ArrayReferenceMetaPosition, var::Symbol) # relies on cycles being forbidden
102102
for opp ∈ opv

0 commit comments

Comments
 (0)