185
185
@inline array_wrapper (A:: Adjoint ) = Adjoint
186
186
@inline array_wrapper (A:: SubArray ) = A. indices
187
187
188
+
189
+
190
+ @generated function __avx__! (
191
+ :: Val{UT} , :: Type{OPS} , :: Type{ARF} , :: Type{AM} , lb:: LB ,
192
+ :: Val{AR} , :: Val{D} , :: Val{IND} , subsetvals, arraydescript, vargs:: Vararg{<:Any,N}
193
+ ) where {UT, OPS, ARF, AM, LB, N, AR, D, IND}
194
+ num_vptrs = length (ARF. parameters):: Int
195
+ vptrs = [gensym (:vptr ) for _ ∈ 1 : num_vptrs]
196
+ call = Expr (:call , lv (:_avx_! ), Val {UT} (), OPS, ARF, AM, :lb )
197
+ for n ∈ 1 : num_vptrs
198
+ push! (call. args, vptrs[n])
199
+ end
200
+ q = Expr (:block )
201
+ j = 0
202
+ assigned_names = Vector {Symbol} (undef, length (AR))
203
+ num_arrays = 0
204
+ for i ∈ eachindex (AR)
205
+ ari = (AR[i]):: Int
206
+ ind = (IND[i]):: Union{Nothing,Int}
207
+ LHS = ind === nothing ? gensym () : vptrs[ind]
208
+ assigned_names[i] = LHS
209
+ d = (D[i]):: Union{Nothing,Int}
210
+ if d === nothing # stridedpointer
211
+ if ari == - 1
212
+ RHS = Expr (:call , :LoopValue )
213
+ else
214
+ num_arrays += 1
215
+ RHS = Expr (:call , lv (:stridedpointer ), Expr (:ref , :vargs , ari), Expr (:ref , :arraydescript , ari))
216
+ end
217
+ else # subsetview
218
+ j += 1
219
+ RHS = Expr (:call , :subsetview , assigned_names[ari], Expr (:call , Expr (:curly , :Val , d)), Expr (:ref , :subsetvals , j))
220
+ end
221
+ push! (q. args, Expr (:(= ), LHS, RHS))
222
+ end
223
+ for n ∈ num_arrays+ 1 : N
224
+ push! (call. args, Expr (:ref , :vargs , n))
225
+ end
226
+ push! (q. args, call)
227
+ Expr (:macrocall , Symbol (" @inbounds" ), LineNumberNode (@__LINE__ , @__FILE__ ), q)
228
+ end
229
+
188
230
# Try to condense in type stable manner
189
231
function generate_call (ls:: LoopSet , IUT)
190
232
operation_descriptions = Expr (:curly , :Tuple )
@@ -200,20 +242,77 @@ function generate_call(ls::LoopSet, IUT)
200
242
foreach (ref -> push! (arrayref_descriptions. args, ArrayRefStruct (ls, ref, arraysymbolinds)), ls. refs_aliasing_syms)
201
243
argmeta = argmeta_and_consts_description (ls, arraysymbolinds)
202
244
loop_bounds = loop_boundaries (ls)
203
-
204
- q = Expr (:call , lv (:_avx_! ), Expr (:call , Expr (:curly , :Val , IUT)), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
205
- foreach (ref -> push! (q. args, vptr (ref)), ls. refs_aliasing_syms)
245
+ inline, U, T = IUT
246
+ if inline
247
+ q = Expr (:call , lv (:_avx_! ), Expr (:call , Expr (:curly , :Val , (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
248
+ foreach (ref -> push! (q. args, vptr (ref)), ls. refs_aliasing_syms)
249
+ else
250
+ arraydescript = Expr (:tuple )
251
+ q = Expr (:call , lv (:__avx__! ), Expr (:call , Expr (:curly , :Val , (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds, arraydescript)
252
+ for array ∈ ls. includedactualarrays
253
+ push! (q. args, Expr (:call , lv (:unwrap_array ), array))
254
+ push! (arraydescript. args, Expr (:call , lv (:array_wrapper ), array))
255
+ end
256
+ end
206
257
foreach (is -> push! (q. args, last (is)), ls. preamble_symsym)
207
258
append! (q. args, arraysymbolinds)
208
259
add_reassigned_syms! (q, ls)
209
260
add_external_functions! (q, ls)
210
261
q
211
262
end
212
263
213
- function setup_call_noinline (ls:: LoopSet , inline = Int8 ( 2 ), U = zero (Int8), T = zero (Int8))
214
- call = generate_call (ls, (inline ,U,T))
264
+ function setup_call_noinline (ls:: LoopSet , U = zero (Int8), T = zero (Int8))
265
+ call = generate_call (ls, (false ,U,T))
215
266
hasouterreductions = length (ls. outer_reductions) > 0
216
- q = ls. preamble
267
+ q = Expr (:block )
268
+ vptrarrays = Expr (:tuple )
269
+ vptrsubsetvals = Expr (:tuple )
270
+ vptrsubsetdims = Expr (:tuple )
271
+ vptrindices = Expr (:tuple )
272
+ stridedpointerLHS = Symbol[]
273
+ loopvalueLHS = Symbol[]
274
+ for ex ∈ ls. preamble. args
275
+ # vptrcalls = Expr(:tuple)
276
+ if ex isa Expr && ex. head === :(= ) && length (ex. args) == 2
277
+ if ex. args[2 ] isa Expr && ex. args[2 ]. head === :call
278
+ gr = first (ex. args[2 ]. args)
279
+ if gr == lv (:stridedpointer )
280
+ array = ex. args[2 ]. args[2 ]
281
+ arrayid = findfirst (a -> a === array, ls. includedactualarrays)
282
+ if arrayid isa Int
283
+ push! (vptrarrays. args, arrayid)
284
+ else
285
+ @assert array ∈ loopvalueLHS
286
+ push! (vptrarrays. args, - 1 )
287
+ end
288
+ push! (vptrsubsetdims. args, nothing )
289
+ vp = first (ex. args):: Symbol
290
+ push! (stridedpointerLHS, vp)
291
+ push! (vptrindices. args, findfirst (a -> vptr (a) == vp, ls. refs_aliasing_syms))
292
+ elseif gr == lv (:subsetview )
293
+ array = ex. args[2 ]. args[2 ]
294
+ vptrarrayid = findfirst (a -> a === array, stridedpointerLHS)# ::Int
295
+ if vptrarrayid === nothing
296
+ @show array, stridedpointerLHS
297
+ @assert vptrarrayid isa Int
298
+ end
299
+ push! (vptrarrays. args, vptrarrayid:: Int )
300
+ push! (vptrsubsetdims. args, ex. args[2 ]. args[3 ]. args[1 ]. args[2 ])
301
+ push! (vptrsubsetvals. args, ex. args[2 ]. args[4 ])
302
+ vp = first (ex. args):: Symbol
303
+ push! (stridedpointerLHS, vp)
304
+ push! (vptrindices. args, findfirst (a -> vptr (a) == vp, ls. refs_aliasing_syms))
305
+ end
306
+ elseif ex. args[2 ] == LoopValue ()
307
+ push! (loopvalueLHS, first (ex. args))
308
+ end
309
+ end
310
+ push! (q. args, ex)
311
+ end
312
+ insert! (call. args, 7 , Expr (:call , Expr (:curly , :Val , vptrarrays)))
313
+ insert! (call. args, 8 , Expr (:call , Expr (:curly , :Val , vptrsubsetdims)))
314
+ insert! (call. args, 9 , Expr (:call , Expr (:curly , :Val , vptrindices)))
315
+ insert! (call. args, 10 , vptrsubsetvals)
217
316
if hasouterreductions
218
317
outer_reducts = Expr (:local )
219
318
for or ∈ ls. outer_reductions
@@ -227,8 +326,6 @@ function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T =
227
326
retv = loopset_return_value (ls, Val (false ))
228
327
call = Expr (:(= ), retv, call)
229
328
push! (q. args, gc_preserve (ls, call))
230
- push! (q. args, Expr (:return , retv))
231
- q = Expr (:block , Expr (:(= ), retv, Expr (:call , Expr (:(-> ), Expr (:tuple , ls. includedactualarrays... ), q), ls. includedactualarrays... )))
232
329
for or ∈ ls. outer_reductions
233
330
op = ls. operations[or]
234
331
var = name (op)
@@ -239,13 +336,11 @@ function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T =
239
336
end
240
337
else
241
338
push! (q. args, gc_preserve (ls, call))
242
- push! (q. args, Expr (:return , :nothing ))
243
- q = Expr (:call , Expr (:(-> ), Expr (:tuple , ls. includedactualarrays... ), q), ls. includedactualarrays... )
244
339
end
245
340
q
246
341
end
247
- function setup_call_inline (ls:: LoopSet , inline = Int8 ( 2 ), U = zero (Int8), T = zero (Int8))
248
- call = generate_call (ls, (inline ,U,T))
342
+ function setup_call_inline (ls:: LoopSet , U = zero (Int8), T = zero (Int8))
343
+ call = generate_call (ls, (true ,U,T))
249
344
hasouterreductions = length (ls. outer_reductions) > 0
250
345
if hasouterreductions
251
346
retv = loopset_return_value (ls, Val (false ))
@@ -273,9 +368,9 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
273
368
# Creating an anonymous function and calling it also achieves the outlining, while still
274
369
# inlining the generated function into the loop preamble.
275
370
if inline == Int8 (2 )
276
- setup_call_inline (ls, Int8 ( 2 ), U, T)
371
+ setup_call_inline (ls, U, T)
277
372
else
278
- setup_call_noinline (ls, Int8 ( 2 ), U, T)
373
+ setup_call_noinline (ls, U, T)
279
374
end
280
375
end
281
376
0 commit comments