@@ -10,6 +10,40 @@ function (::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V})(p::Ptr{UInt}) where {UNROLL,OPS,A
10
10
ThreadingUtilities. store! (p, ret, 7 )
11
11
nothing
12
12
end
13
+ @generated function Base. pointer (:: AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V} ) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
14
+ f = AVX {UNROLL,OPS,ARF,AM,LPSYM,LB,V} ()
15
+ precompile (f, (Ptr{UInt},))
16
+ quote
17
+ $ (Expr (:meta ,:inline ))
18
+ @cfunction ($ f, Cvoid, (Ptr{UInt},))
19
+ end
20
+ end
21
+
22
+ function launch! (p:: Ptr{UInt} , fptr:: Ptr{Cvoid} , args:: Tuple{LB,V} ) where {LB,V}
23
+ offset = ThreadingUtilities. store! (p, fptr, 0 )
24
+ offset = ThreadingUtilities. store! (p, args, offset)
25
+ nothing
26
+ end
27
+ function launch (
28
+ :: Val{UNROLL} , :: Val{OPS} , :: Val{ARF} , :: Val{AM} , :: Val{LPSYM} , lb:: LB , vargs:: V , tid
29
+ ) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
30
+ p = ThreadingUtilities. taskpointer (tid)
31
+ f = AVX {UNROLL,OPS,ARF,AM,LPSYM,LB,V} ()
32
+ fptr = pointer (f)
33
+ while true
34
+ if ThreadingUtilities. _atomic_cas_cmp! (p, ThreadingUtilities. SPIN, ThreadingUtilities. STUP)
35
+ launch! (p, fptr, (lb,vargs))
36
+ @assert ThreadingUtilities. _atomic_cas_cmp! (p, ThreadingUtilities. STUP, ThreadingUtilities. TASK)
37
+ return
38
+ elseif ThreadingUtilities. _atomic_cas_cmp! (p, ThreadingUtilities. WAIT, ThreadingUtilities. STUP)
39
+ launch! (p, fptr, (lb,vargs))
40
+ @assert ThreadingUtilities. _atomic_cas_cmp! (p, ThreadingUtilities. STUP, ThreadingUtilities. LOCK)
41
+ ThreadingUtilities. wake_thread! (tid % UInt)
42
+ return
43
+ end
44
+ ThreadingUtilities. pause ()
45
+ end
46
+ end
13
47
14
48
# function approx_cbrt(x)
15
49
# s = significand(x)
18
52
# # 40 + 0.00020833333333333335*(x-64000) -2.1701388888888896e-9*(x-64000)^2*0.5 + 5.6514033564814844e-14 * (x-64000)^3/6
19
53
# end
20
54
21
- function choose_threads (:: StaticInt{C} , x) where {C}
55
+ function choose_num_threads (:: StaticInt{C} , x) where {C}
22
56
nt = ifelse (gt (num_threads (), num_cores ()), num_cores (), num_threads ())
23
57
fx = Base. uitofp (Float64, x)
24
- min (Base. fptosi (Int, Base. ceil_llvm (5.0852672001495816e-11 * C* Base. sqrt_llvm (fx))), nt)
58
+ min (Base. fptoui (UInt, Base. ceil_llvm (5.0852672001495816e-11 * C* Base. sqrt_llvm (fx))), UInt (nt))
59
+ end
60
+ function push_loop_length_expr! (q:: Expr , ls:: LoopSet )
61
+ l = 1
62
+ ndynamic = 0
63
+ mulexpr = length (ls. loops) == 1 ? q : Expr (:call , lv (:vmul_fast ))
64
+ for loop ∈ ls. loops
65
+ if isstaticloop (loop)
66
+ l *= length (loop)
67
+ else
68
+ ndynamic += 1
69
+ if ndynamic < 3
70
+ push! (mulexpr. args, loop. lensym)
71
+ else
72
+ mulexpr = Expr (:call , lv (:vmul_fast ), mulexpr, loop. lensym)
73
+ end
74
+ end
75
+ end
76
+ if length (ls. loops) == 1
77
+ ndynamic == 0 && push! (q. args, l)
78
+ elseif l == 1
79
+ push! (q. args, mulexpr)
80
+ elseif ndynamic == 0
81
+ push! (q. args, l)
82
+ elseif ndynamic == 1
83
+ push! (mulexpr. args, l)
84
+ push! (q. args, mulexpr)
85
+ else
86
+ push! (q. args, Expr (:call , :vmul_fast , mulexpr, l))
87
+ end
88
+ nothing
89
+ end
90
+ function divrem_fast (numerator, denominator)
91
+ d = Base. udiv_int (numerator, denominator)
92
+ r = numerator - denominator* d
93
+ d, r
94
+ end
95
+
96
+ function outer_reduct_combine_expressions (ls:: LoopSet , retv)
97
+ q = Expr (:block , :(var"#load#thread#ret#" = ThreadingUtilities. store! (var"#thread#ptr#" , typeof ($ retv), 7 )))
98
+ for (i,or) ∈ enumerate (ls. outer_reductions)
99
+ op = ls. operations[or]
100
+ var = name (op)
101
+ mvar = mangledvar (op)
102
+ instr = instruction (op)
103
+ out = Symbol (mvar, " ##onevec##" )
104
+ instrcall = callexp (instr)
105
+ push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), out))
106
+ if length (ls. outer_reductions) > 1
107
+ push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), Expr (:call , GlobalRef (Core, :getfield ), Symbol (" #load#thread#ret#" ), i, false )))
108
+ else
109
+ push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), Symbol (" #load#thread#ret#" )))
110
+ end
111
+ push! (q. args, Expr (:(= ), out, Expr (:call , :data , instrcall)))
112
+ end
113
+ q
114
+ end
115
+
116
+ function thread_loop_summary! (ls, threadedloop:: Loop , u₁loop:: Loop , u₂loop:: Loop , vloop:: Loop , issecondthreadloop:: Bool )
117
+ threadloopnumtag = Int (issecondthreadloop)
118
+ lensym = Symbol (" #len#thread#$threadloopnumtag #" )
119
+ define_len = if isstaticloop (threadedloop)
120
+ :($ lensym = $ (length (threadedloop)))
121
+ else
122
+ :($ lensym = $ ((threadedloop. lensym)))
123
+ end
124
+ unroll_factor = 1
125
+ if threadedloop === vloop
126
+ unroll_factor *= W
127
+ end
128
+ if threadedloop === u₁loop
129
+ unroll_factor *= u₁
130
+ elseif threadedloop === u₂loop
131
+ unroll_factor *= u₂
132
+ end
133
+ num_unroll_sym = Symbol (" #num#unrolls#thread#$threadloopnumtag #" )
134
+ define_num_unrolls = if unroll_factor == 1
135
+ :($ num_unroll_sym = $ lensym)
136
+ else
137
+ :($ num_unroll_sym = Base. udiv_int ($ lensym, $ (UInt (unroll_factor))))
138
+ end
139
+ iterstart_sym = Symbol (" #iter#start#$threadloopnumtag #" )
140
+ iterstop_sym = Symbol (" #iter#stop#$threadloopnumtag #" )
141
+ blksz_sym = Symbol (" #nblock#size#thread#$threadloopnumtag #" )
142
+ loopstart = if isknown (first (threadedloop))
143
+ :($ iterstart_sym = $ (gethint (first (threadedloop))))
144
+ else
145
+ :($ iterstart_sym = $ (getsym (first (threadedloop))))
146
+ end
147
+ if isknown (step (threadedloop))
148
+ mf = gethint (threadedloop) * unroll_factor
149
+ if isone (mf)
150
+ iterstop = :($ iterstop_sym = $ iterstart_sym + $ blksz_sym)
151
+ looprange = :(CloseOpen ($ iterstart_sym, $ iterstop_sym))
152
+ lastrange = if isknown (last (threadedloop))
153
+ :(CloseOpen ($ iterstart_sym,$ (gethint (threadedloop)+ 1 )))
154
+ else # we want all the intervals to have the same type.
155
+ :(CloseOpen ($ iterstart_sym,$ (getsym (threadedloop))+ 1 ))
156
+ end
157
+ else
158
+ iterstop = :($ iterstop_sym = $ iterstart_sym + $ blksz_sym * $ mf)
159
+ looprange = :($ iterstart_sym: StaticInt {$mf} (): $ iterstop_sym- 1 )
160
+ lastrange = if isknown (last (threadedloop))
161
+ :($ iterstart_sym: StaticInt {$mf} (): $ (gethint (threadedloop)))
162
+ else
163
+ :($ iterstart_sym: StaticInt {$mf} (): $ (getsym (threadedloop)))
164
+ end
165
+ end
166
+ else
167
+ stepthread_sym = Symbol (" #step#thread#$threadloopnumtag #" )
168
+ pushpreamble! (ls, :($ stepthread_sym = $ unroll_factor * $ (getsym (step (threadedloop)))))
169
+ iterstop = :($ iterstop_sym = $ iterstart_sym + $ blksz_sym * $ stepthread_sym)
170
+ looprange = :($ iterstart_sym: $ stepthread_sym: $ iterstop_sym- 1 )
171
+ lastrange = if isknown (last (threadedloop))
172
+ :($ iterstart_sym: $ stepthread_sym: $ (gethint (threadedloop)))
173
+ else
174
+ :($ iterstart_sym: $ stepthread_sym: $ (getsym (threadedloop)))
175
+ end
176
+ end
177
+ define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange
25
178
end
26
179
27
- function thread_single_loop_expr (ls:: LoopSet , UNROLL, id)
180
+ function thread_single_loop_expr (ls:: LoopSet , ua:: UnrollArgs , valid_thread_loop, c, UNROLL, OPS, ARF, AM, LPSYM)
181
+ choose_nthread = :(choose_num_threads (StaticInt {$c} ()))
182
+ push_loop_length_expr! (choose_nthread, ls)
183
+ threadedid = findfirst (valid_thread_loop):: Int
184
+ @unpack u₁loop, u₂loop, vloop, u₁, u₂ = ua
185
+ W = ls. vector_width[]
186
+ threadedloop = getloop (ls, threadedid)
187
+ define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange = thread_loop_summary! (ls, threadedloop, u₁loop, u₂loop, vloop, 0 )
188
+ loopboundexpr = Expr (:tuple )
189
+ lastboundexpr = Expr (:tuple )
190
+ for (i,loop) ∈ enumerate (threadedloop)
191
+ if loop === threadedloop
192
+ push! (loopboundexpr. args, looprange)
193
+ push! (lastboundexpr. args, lastrange)
194
+ else
195
+ loop_boundary! (loopboundexpr, loop)
196
+ loop_boundary! (lastboundexpr, loop)
197
+ end
198
+ end
199
+ _avx_call_ = :(_avx_! (Val {$UNROLL} (), Val {$OPS} (), Val {$ARF} (), Val {$AM} (), Val {$LPSYM} (), $ lastboundexpr, var"#vargs#" ))
200
+ update_return_values = if length (ls. outer_reductions) > 0
201
+ retv = loopset_return_value (ls, Val (false ))
202
+ _avx_call_ = Expr (:(= ), retv, _avx_call_)
203
+ outer_reduct_combine_expressions (ls, retv)
204
+ else
205
+ nothing
206
+ end
207
+ q = quote
208
+ var"#nthreads#" = $ choose_nthread # UInt
209
+ $ define_len % UInt
210
+ $ define_num_unrolls
211
+ var"#nthreads#" = Base. min (var"#nthreads#" , $ num_unrolls)
212
+ var"#nrequest#" = (var"#nthreads#" % UInt32) - 0x00000001
213
+ var"#nrequest#" == 0x00000000 && return LoopVectorization. _avx_! (Val {$UNROLL} (), Val{$ OPS}, Val {$ARF} (), Val {$AM} (), Val {$LPSYM} (), var"#lv#tuple#args#" )
214
+ var"#threads#" , var"#torelease#" = LoopVectorization. _request_threads (Threads. threadid (), var"#nrequest#" )
28
215
216
+ var"#base#block#size#thread#0#" , var"#nrem#thread#" = LoopVectorization. divrem_fast (num_unrolls, var"#nthreads#" )
217
+ $ loopstart
218
+
219
+ var"#thread#launch#count#" = 0x00000000
220
+ var"#thread#id#" = 0x00000000
221
+ var"#thread#mask#" = CheapThreads. mask (var"#threads#" )
222
+ var"#threads#remain#" = true
223
+ while var"#threads#remain#"
224
+ VectorizationBase. assume (var"#thread#mask#" ≠ zero (var"#thread#mask#" ))
225
+ var"#trailzing#zeros#" = Base. trailing_zeros (var"#thread#mask#" ) % UInt32
226
+ var"#thread#launch#count#" += 0x00000001
227
+ var"#nblock#size#thread#0#" = Core. ifelse (
228
+ var"#thread#launch#count#" < (var"#nrem#thread#" % Base. typeof (var"#threadid#" )),
229
+ var"#base#block#size#thread#0#" + Base. one (var"#base#block#size#thread#0#" ),
230
+ var"#base#block#size#thread#0#"
231
+ )
232
+ var"#trailzing#zeros#" += 0x00000001
233
+ $ iterstop
234
+ var"#thread#id#" += var"#trailzing#zeros#"
235
+
236
+ LoopVectorization. launch (
237
+ Val {$UNROLL} (), Val {$OPS} (), Val {$ARF} (), Val {$AM} (), Val {$LPSYM} (),
238
+ $ loopboundexpr, var"#vargs#" , var"#thread#id#"
239
+ )
240
+
241
+ var"#thread#mask#" >>>= var"#trailzing#zeros#"
242
+
243
+ var"#iter#start#0#" = var"#iter#stop#0#"
244
+ var"#threads#remain#" = var"#thread#launch#count#" ≠ var"$nrequest#"
245
+ end
246
+ $ _avx_call_
247
+ var"#thread#id#" = 0x00000000
248
+ var"#thread#mask#" = CheapThreads. mask (var"#threads#" )
249
+ var"#threads#remain#" = true
250
+ while var"#threads#remain#"
251
+ VectorizationBase. assume (var"#thread#mask#" ≠ zero (var"#thread#mask#" ))
252
+ var"#trailzing#zeros#" = Base. trailing_zeros (var"#thread#mask#" ) % UInt32
253
+ var"#trailzing#zeros#" += 0x00000001
254
+ var"#thread#mask#" >>>= var"#trailzing#zeros#"
255
+ var"#thread#id#" += var"#trailzing#zeros#"
256
+ var"#thread#ptr#" = ThreadingUtilities. taskpointer (var"#thread#id#" )
257
+ ThreadingUtilities. __wait (var"#thread#ptr#" )
258
+ $ update_return_values
259
+ var"#threads#remain#" = var"#thread#mask#" ≠ 0x00000000
260
+ end
261
+ CheapThreads. free_threads! (var"#torelease#" )
262
+ end
263
+ length (ls. outer_reductions) > 0 ? push! (q. args, retv) : push! (q. args, nothing )
264
+ q
29
265
end
30
266
function thread_multiple_loop_expr (ls:: LoopSet , UNROLL, valid_thread_loop)
31
267
32
268
end
33
269
34
- function avx_threads_expr (ls:: LoopSet , UNROLL )
270
+ function valid_thread_loops (ls:: LoopSet )
35
271
order, u₁loop, u₂loop, vectorized, u₁, u₂, c, shouldinline = choose_order_cost (ls)
272
+ # NOTE: `names` are being placed in the opposite order here versus normal lowering!
273
+ copyto! (names (ls), order); init_loop_map! (ls)
274
+ ua = UnrollArgs (getloop (ls, u₁loop), getloop (ls, u₂loop), getloop (ls, vloop), u₁, u₂, u₂)
36
275
valid_thread_loop = fill (true , length (order))
37
276
for op ∈ operations (ls)
38
277
if isstore (op) && (length (reduceddependencies (op)) > 0 )
@@ -45,6 +284,10 @@ function avx_threads_expr(ls::LoopSet, UNROLL)
45
284
end
46
285
end
47
286
end
287
+ valid_thread_loop, ua, c
288
+ end
289
+ function avx_threads_expr (ls:: LoopSet , UNROLL)
290
+ valid_thread_loop, us, c = valid_thread_loops (ls)
48
291
num_candiates = sum (valid_thread_loop)
49
292
# num_to_thread = min(num_candiates, 2)
50
293
# candidate_ids =
@@ -54,8 +297,7 @@ function avx_threads_expr(ls::LoopSet, UNROLL)
54
297
thread_single_loop_expr (ls, UNROLL, findfirst (isone, valid_thread_loop):: Int )
55
298
else
56
299
thread_multiple_loop_expr (ls, UNROLL, vald_thread_loop)
57
- end
58
-
300
+ end
59
301
end
60
302
61
303
0 commit comments