@@ -185,6 +185,7 @@ function add_mref!(
185
185
) where {T}
186
186
@assert B ≤ 0 " Batched arrays not supported yet."
187
187
_add_mref! (sptrs, ls, ar, typetosym (T), C, B, sp, name)
188
+ sizeof (T)
188
189
end
189
190
typetosym (:: Type{T} ) where {T<: NativeTypes } = (VectorizationBase. JULIA_TYPES[T]):: Symbol
190
191
typetosym (T) = T
@@ -239,14 +240,15 @@ function add_mref!(
239
240
sptrs:: Expr , :: LoopSet , :: ArrayReferenceMeta , @nospecialize (_:: Type{VectorizationBase.FastRange{T,F,S,O}} ),
240
241
:: Int , :: Int , sp:: Vector{Int} , name:: Symbol
241
242
) where {T,F,S,O}
242
- extract_gsp! (sptrs, name)
243
+ extract_gsp! (sptrs, name)
244
+ sizeof (T)
243
245
end
244
246
function create_mrefs! (
245
247
ls:: LoopSet , arf:: Vector{ArrayRefStruct} , as:: Vector{Symbol} , os:: Vector{Symbol} ,
246
248
nopsv:: Vector{NOpsType} , expanded:: Vector{Bool} , :: Type{Tuple{}}
247
249
)
248
250
length (arf) == 0 || throw (ArgumentError (" Length of array ref vector should be 0 if there are no stridedpointers." ))
249
- Vector {ArrayReferenceMeta} (undef, length (arf))
251
+ Vector {ArrayReferenceMeta} (undef, length (arf)), Int[]
250
252
end
251
253
function stabilize_grouped_stridedpointer_type (C, B, R)
252
254
N = (length (C)):: Int
@@ -271,12 +273,12 @@ function create_mrefs!(
271
273
Cv,Bv,Rv = stabilize_grouped_stridedpointer_type (C, B, R)
272
274
_create_mrefs! (ls, arf, as, os, nopsv, expanded, P. parameters, Cv, Bv, Rv)
273
275
end
274
-
275
276
function _create_mrefs! (
276
277
ls:: LoopSet , arf:: Vector{ArrayRefStruct} , as:: Vector{Symbol} , os:: Vector{Symbol} ,
277
278
nopsv:: Vector{NOpsType} , expanded:: Vector{Bool} , P:: Core.SimpleVector , C:: Vector{Int} , B:: Vector{Int} , R:: Vector{Tuple{NTuple{8,Int},Int}}
278
279
)
279
280
mrefs:: Vector{ArrayReferenceMeta} = Vector {ArrayReferenceMeta} (undef, length (arf))
281
+ elementbytes:: Vector{Int} = Vector {Int} (undef, length (arf))
280
282
sptrs = Expr (:tuple )
281
283
# pushpreamble!(ls, Expr(:(=), sptrs, :(VectorizationBase.stridedpointers(getfield(vargs, 1, false)))))
282
284
pushpreamble! (ls, Expr (:(= ), sptrs, :(VectorizationBase. stridedpointers (getfield (var"#vargs#" , 1 , false )))))
@@ -292,6 +294,7 @@ function _create_mrefs!(
292
294
# if isassigned(rank_to_sps, k)
293
295
Cₖ, sp = rank_to_sps[k]
294
296
permute_mref! (ar, Cₖ, sp)
297
+ elementbytes[i] = elementbytes[k]
295
298
# end
296
299
break
297
300
end
@@ -300,11 +303,11 @@ function _create_mrefs!(
300
303
j += 1
301
304
sp = rank_to_sortperm (R[j]):: Vector{Int}
302
305
rank_to_sps[i] = (C[j],sp)
303
- add_mref! (sptrs, ls, ar, P[j], C[j], B[j], sp, vptr (ar))
306
+ elementbytes[i] = add_mref! (sptrs, ls, ar, P[j], C[j], B[j], sp, vptr (ar))
304
307
end
305
308
mrefs[i] = ar
306
309
end
307
- mrefs
310
+ mrefs, elementbytes
308
311
end
309
312
310
313
function num_parameters (AM)
@@ -408,11 +411,19 @@ function isexpanded(ls::LoopSet, ops::Vector{OperationStruct}, nopsv::Vector{NOp
408
411
false
409
412
end
410
413
end
414
+ function mref_elbytes (os:: OperationStruct , mrefs:: Vector{ArrayReferenceMeta} , elementbytes:: Vector{Int} )
415
+ if isload (os) | isstore (os)
416
+ mrefs[os. array], elementbytes[os. array]
417
+ else
418
+ NOTAREFERENCE, 4
419
+ end
420
+ end
411
421
function add_op! (
412
422
ls:: LoopSet , instr:: Instruction , ops:: Vector{OperationStruct} , nopsv:: Vector{NOpsType} , expandedv:: Vector{Bool} , i:: Int ,
413
- mrefs:: Vector{ArrayReferenceMeta} , opsymbol, elementbytes:: Int
423
+ mrefs:: Vector{ArrayReferenceMeta} , opsymbol, elementbytes:: Vector{ Int}
414
424
)
415
425
os = ops[i]
426
+ mref, elbytes = mref_elbytes (os, mrefs, elementbytes)
416
427
# opsymbol = (isconstant(os) && instr != LOOPCONSTANT) ? instr.instr : opsymbol
417
428
# If it's a CartesianIndex add or subtract, we may have to add multiple operations
418
429
expanded = expandedv[i]# isexpanded(ls, ops, nopsv, i)
@@ -421,10 +432,9 @@ function add_op!(
421
432
optyp = optype (os)
422
433
if ! expanded
423
434
op = Operation (
424
- length (operations (ls)), opsymbol, elementbytes , instr,
435
+ length (operations (ls)), opsymbol, elbytes , instr,
425
436
optyp, loopdependencies (ls, os, true ), reduceddependencies (ls, os, true ),
426
- Operation[], (isload (os) | isstore (os)) ? mrefs[os. array] : NOTAREFERENCE,
427
- childdependencies (ls, os, true )
437
+ Operation[], mref, childdependencies (ls, os, true )
428
438
)
429
439
push! (ls. operations, op)
430
440
push! (opoffsets, opoffsets[end ] + 1 )
@@ -435,10 +445,9 @@ function add_op!(
435
445
for offset = 0 : nops- 1
436
446
sym = nops === 1 ? opsymbol : expandedopname (opsymbol, offset)
437
447
op = Operation (
438
- length (operations (ls)), sym, elementbytes, instr,
439
- optyp, loopdependencies (ls, os, false , offset), reduceddependencies (ls, os, false , offset),
440
- Operation[], (isload (os) | isstore (os)) ? mrefs[os. array] : NOTAREFERENCE,
441
- childdependencies (ls, os, false , offset)
448
+ length (operations (ls)), sym, elbytes, instr, optyp,
449
+ loopdependencies (ls, os, false , offset), reduceddependencies (ls, os, false , offset),
450
+ Operation[], mref, childdependencies (ls, os, false , offset)
442
451
)
443
452
push! (ls. operations, op)
444
453
end
@@ -491,8 +500,8 @@ function add_parents_to_ops!(ls::LoopSet, ops::Vector{OperationStruct}, constoff
491
500
constoffset
492
501
end
493
502
function add_ops! (
494
- ls:: LoopSet , instr:: Vector{Instruction} , ops:: Vector{OperationStruct} , mrefs:: Vector{ArrayReferenceMeta} ,
495
- opsymbols:: Vector{Symbol} , constoffset:: Int , nopsv:: Vector{NOpsType} , expandedv:: Vector{Bool} , elementbytes :: Int
503
+ ls:: LoopSet , instr:: Vector{Instruction} , ops:: Vector{OperationStruct} , mrefs:: Vector{ArrayReferenceMeta} , elementbytes :: Vector{Int} ,
504
+ opsymbols:: Vector{Symbol} , constoffset:: Int , nopsv:: Vector{NOpsType} , expandedv:: Vector{Bool}
496
505
)
497
506
# @show ls.loopsymbols ls.loopsymbol_offsets
498
507
for i ∈ eachindex (ops)
@@ -584,12 +593,6 @@ function avx_loopset!(
584
593
ls:: LoopSet , instr:: Vector{Instruction} , ops:: Vector{OperationStruct} , arf:: Vector{ArrayRefStruct} ,
585
594
AM:: Vector{Any} , LPSYM:: Vector{Any} , LB:: Core.SimpleVector , vargs:: Core.SimpleVector
586
595
)
587
- # TODO : check outer reduction types instead
588
- elementbytes = if length (vargs[1 ]. parameters) > 0
589
- sizeofeltypes (vargs[1 ]. parameters[1 ]. parameters)
590
- else
591
- 8
592
- end
593
596
pushpreamble! (ls, :((var"#loop#bounds#" , var"#vargs#" ) = var"#lv#tuple#args#" ))
594
597
add_loops! (ls, LPSYM, LB)
595
598
resize! (ls. loop_order, ls. loopsymbol_offsets[end ])
@@ -599,12 +602,12 @@ function avx_loopset!(
599
602
expandedv = [isexpanded (ls, ops, nopsv, i) for i ∈ eachindex (ops)]
600
603
601
604
resize! (ls. loopindexesbit, length (ls. loops)); fill! (ls. loopindexesbit, false );
602
- mrefs = create_mrefs! (ls, arf, arraysymbolinds, opsymbols, nopsv, expandedv, vargs[1 ])
605
+ mrefs, elementbytes = create_mrefs! (ls, arf, arraysymbolinds, opsymbols, nopsv, expandedv, vargs[1 ])
603
606
for mref ∈ mrefs
604
607
push! (ls. includedactualarrays, vptr (mref))
605
608
end
606
609
# extra args extraction
607
- extractind = add_ops! (ls, instr, ops, mrefs, opsymbols, 1 , nopsv, expandedv, elementbytes )
610
+ extractind = add_ops! (ls, instr, ops, mrefs, elementbytes, opsymbols, 1 , nopsv, expandedv)
608
611
extractind = process_metadata! (ls, AM, extractind)
609
612
extractind = add_array_symbols! (ls, arraysymbolinds, extractind)
610
613
extractind = extract_external_functions! (ls, extractind, vargs)
@@ -645,12 +648,11 @@ function _avx_loopset(
645
648
ls = LoopSet (:LoopVectorization )
646
649
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
647
650
set_hw! (ls, rs, rc, cls, l1, l2, l3); ls. vector_width = W; ls. isbroadcast = isbroadcast
648
- avx_loopset! (
649
- ls, instr, ops,
650
- ArrayRefStruct[ARFsv... ],
651
- tovector (AMsv), tovector (LPSYMsv), LBsv, vargs
652
- ):: LoopSet
653
- ls
651
+ arsv = Vector {ArrayRefStruct} (undef, length (ARFsv))
652
+ for i ∈ eachindex (arsv)
653
+ arsv[i] = ARFsv[i]
654
+ end
655
+ avx_loopset! (ls, instr, ops, arsv, tovector (AMsv), tovector (LPSYMsv), LBsv, vargs)
654
656
end
655
657
656
658
@static if VERSION ≥ v " 1.7.0-DEV.421"
0 commit comments