@@ -167,8 +167,6 @@ function lower_load!(
167
167
q:: Expr , op:: Operation , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , U:: Int ,
168
168
suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
169
169
)
170
- # @show op.instruction
171
- # @show unrolled, loopdependencies(op)
172
170
if vectorized ∈ loopdependencies (op)
173
171
lower_load_vectorized! (q, op, vectorized, W, unrolled, U, suffix, mask)
174
172
else
@@ -205,7 +203,6 @@ function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
205
203
Uh = Uh2 >> 1
206
204
reduce_range! (q, toreduct, instr, Uh, Uh2)
207
205
Uh == 1 && break
208
- # @show Uh
209
206
Uh2 = Uh
210
207
iter += 1 ; iter > 4 && throw (" Oops! This seems to be excessive unrolling." )
211
208
end
@@ -220,11 +217,10 @@ pvariable_name(op::Operation, suffix) = Symbol(pvariable_name(op, nothing), suff
220
217
function reduce_unroll! (q, op, U, unrolled)
221
218
loopdeps = loopdependencies (op)
222
219
isunrolled = unrolled ∈ loopdeps
223
- if unrolled ∉ reduceddependencies (op)
220
+ if ( unrolled ∉ reduceddependencies (op) )
224
221
U = isunrolled ? U : 1
225
222
return U, isunrolled
226
223
end
227
- unrolled ∈ reduceddependencies (op) || return U
228
224
var = mangledvar (op)
229
225
instr = first (parents (op)). instruction
230
226
reduce_expr! (q, var, instr, U) # assigns reduction to storevar
@@ -278,7 +274,6 @@ function lower_store_vectorized!(
278
274
for u ∈ 0 : U- 1
279
275
name, mo = name_mo (var, op, u, W, vecnotunrolled, unrolled)
280
276
instrcall = Expr (:call ,lv (:vstore! ), ptr, name, mo)
281
- # @show mask, vecnotunrolled, u, U
282
277
if mask != = nothing && (vecnotunrolled || u == U - 1 )
283
278
push! (instrcall. args, mask)
284
279
end
@@ -365,7 +360,6 @@ function lower_compute!(
365
360
# cache unroll and tiling check of parents
366
361
# not broadcasted, because we use frequent checks of individual bools
367
362
# making BitArrays inefficient.
368
- # @show instr parentsunrolled
369
363
# parentsyms = [opp.variable for opp ∈ parents(op)]
370
364
Uiter = opunrolled ? U - 1 : 0
371
365
maskreduct = mask != = nothing && isreduction (op) && any (opp -> opp. variable === var, parents_op)
@@ -401,27 +395,13 @@ function lower_compute!(
401
395
end
402
396
push! (instrcall. args, parent)
403
397
end
404
- if maskreduct && u == Uiter # only mask last
398
+ if maskreduct && ( u == Uiter || unrolled != = vectorized) # only mask last
405
399
push! (q. args, Expr (:(= ), varsym, Expr (:call , lv (:vifelse ), mask, instrcall, varsym)))
406
400
else
407
401
push! (q. args, Expr (:(= ), varsym, instrcall))
408
402
end
409
403
end
410
404
end
411
- function lower! (
412
- q:: Expr , op:: Operation , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , tiled:: Symbol , U:: Int ,
413
- suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
414
- )
415
- if isload (op)
416
- lower_load! (q, op, vectorized, W, unrolled, U, suffix, mask)
417
- elseif isstore (op)
418
- lower_store! (q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
419
- elseif iscompute (op)
420
- lower_compute! (q, op, vectorized, W, unrolled, U, suffix, mask)
421
- else
422
- lower_constant! (q, op, vectorized, W, unrolled, U, suffix, mask)
423
- end
424
- end
425
405
function lower_constant! (
426
406
q:: Expr , op:: Operation , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , U:: Int ,
427
407
suffix:: Union{Nothing,Int} , mask:: Any = nothing
@@ -470,9 +450,23 @@ function lower_constant!(
470
450
q:: Expr , ops:: AbstractVector{Operation} , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , U:: Int ,
471
451
suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
472
452
)
473
- foreach (op -> lower_constan ! (q, op, vectorized, W, unrolled, U, suffix, mask), ops)
453
+ foreach (op -> lower_constant ! (q, op, vectorized, W, unrolled, U, suffix, mask), ops)
474
454
end
475
455
456
+ function lower! (
457
+ q:: Expr , op:: Operation , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , tiled:: Symbol , U:: Int ,
458
+ suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
459
+ )
460
+ if isconstant (op)
461
+ lower_constant! (q, op, vectorized, W, unrolled, U, suffix, mask)
462
+ elseif isload (op)
463
+ lower_load! (q, op, vectorized, W, unrolled, U, suffix, mask)
464
+ elseif iscompute (op)
465
+ lower_compute! (q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
466
+ else # if isstore(op)
467
+ lower_store! (q, op, vectorized, W, unrolled, U, suffix, mask)
468
+ end
469
+ end
476
470
function lower! (
477
471
q:: Expr , ops:: AbstractVector{<:AbstractVector{Operation}} , vectorized:: Symbol , W:: Symbol , unrolled:: Symbol , tiled:: Symbol , U:: Int ,
478
472
suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
@@ -501,7 +495,6 @@ function lower_nest(
501
495
nisvectorized = loopsym === vectorized
502
496
nisunrolled = false
503
497
nistiled = false
504
- # @show n, mask
505
498
if istiled
506
499
if n == nloops
507
500
loopsym = tiledsym (loopsym)
@@ -514,7 +507,6 @@ function lower_nest(
514
507
unrolled = last (order)
515
508
nisunrolled = n == nloops
516
509
end
517
- # @show unrolled, order
518
510
blockq = Expr (:block )
519
511
n == 1 || push! (blockq. args, Expr (:(= ), order[n- 1 ], loopstart))
520
512
loopq = if exprtype === :block
@@ -581,7 +573,6 @@ function add_vec_rem_iter(
581
573
loopq
582
574
end
583
575
function lower_set (ls:: LoopSet , vectorized:: Symbol , U:: Int , T:: Int , W:: Symbol , :: Nothing , Uexprtype:: Symbol )
584
- # @show U, T, W
585
576
loopstart = 0
586
577
istiled = T != - 1
587
578
order = names (ls)
@@ -620,11 +611,11 @@ function lower_set_unrolled_is_vectorized(ls::LoopSet, vectorized::Symbol, U::In
620
611
loopq
621
612
end
622
613
function initialize_outer_reductions! (
623
- q:: Expr , op:: Operation , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , unrolled :: Symbol , suffix:: Union{Symbol,Nothing} = nothing
614
+ q:: Expr , op:: Operation , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , vectorized :: Symbol , suffix:: Union{Symbol,Nothing} = nothing
624
615
)
625
616
# T = op.elementbytes == 8 ? :Float64 : :Float32
626
617
z = Expr (:call , REDUCTION_ZERO[op. instruction], typeT)
627
- if unrolled ∈ reduceddependencies (op)
618
+ if vectorized ∈ reduceddependencies (op)
628
619
z = Expr (:call , lv (:vbroadcast ), W, z)
629
620
end
630
621
mvar = variable_name (op, suffix)
@@ -633,16 +624,15 @@ function initialize_outer_reductions!(
633
624
end
634
625
nothing
635
626
end
636
- function initialize_outer_reductions! (q:: Expr , ls:: LoopSet , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , unrolled :: Symbol , suffix:: Union{Symbol,Nothing} = nothing )
637
- foreach (or -> initialize_outer_reductions! (q, ls. operations[or], Umin, Umax, W, typeT, unrolled , suffix), ls. outer_reductions)
627
+ function initialize_outer_reductions! (q:: Expr , ls:: LoopSet , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , vectorized :: Symbol , suffix:: Union{Symbol,Nothing} = nothing )
628
+ foreach (or -> initialize_outer_reductions! (q, ls. operations[or], Umin, Umax, W, typeT, vectorized , suffix), ls. outer_reductions)
638
629
end
639
- function initialize_outer_reductions! (ls:: LoopSet , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , unrolled :: Symbol , suffix:: Union{Symbol,Nothing} = nothing )
640
- initialize_outer_reductions! (ls. preamble, ls, Umin, Umax, W, typeT, unrolled , suffix)
630
+ function initialize_outer_reductions! (ls:: LoopSet , Umin:: Int , Umax:: Int , W:: Symbol , typeT:: Symbol , vectorized :: Symbol , suffix:: Union{Symbol,Nothing} = nothing )
631
+ initialize_outer_reductions! (ls. preamble, ls, Umin, Umax, W, typeT, vectorized , suffix)
641
632
end
642
- function add_upper_outer_reductions (ls:: LoopSet , loopq:: Expr , Ulow:: Int , Uhigh:: Int , W:: Symbol , typeT:: Symbol , unrolledloop:: Loop )
643
- unrolled = unrolledloop. itersymbol
633
+ function add_upper_outer_reductions (ls:: LoopSet , loopq:: Expr , Ulow:: Int , Uhigh:: Int , W:: Symbol , typeT:: Symbol , unrolledloop:: Loop , vectorized:: Symbol )
644
634
ifq = Expr (:block )
645
- initialize_outer_reductions! (ifq, ls, Ulow, Uhigh, W, typeT, unrolled )
635
+ initialize_outer_reductions! (ifq, ls, Ulow, Uhigh, W, typeT, vectorized )
646
636
push! (ifq. args, loopq)
647
637
reduce_range! (ifq, ls, Ulow, Uhigh)
648
638
comparison = if unrolledloop. hintexact
@@ -786,7 +776,7 @@ function lower_unrolled_dynamic!(
786
776
if manageouterreductions
787
777
# Umax = (!static_unroll && U > 2) ? U >> 1 : U
788
778
Ureduct = U > 6 ? 4 : U
789
- initialize_outer_reductions! (q, ls, 0 , Ureduct, W, typeT, last (names (ls)))
779
+ initialize_outer_reductions! (q, ls, 0 , Ureduct, W, typeT, vectorized) # last(names(ls)))
790
780
else
791
781
Ureduct = - 1
792
782
end
@@ -798,7 +788,7 @@ function lower_unrolled_dynamic!(
798
788
if firstiter # first iter
799
789
loopq = lower_set (ls, vectorized, Ut, T, W, nothing , Uexprtype)
800
790
if T == - 1 && manageouterreductions && U > 4
801
- loopq = add_upper_outer_reductions (ls, loopq, Ureduct, U, W, typeT, unrolledloop)
791
+ loopq = add_upper_outer_reductions (ls, loopq, Ureduct, U, W, typeT, unrolledloop, vectorized )
802
792
end
803
793
push! (q. args, loopq)
804
794
elseif U == 1 #
@@ -863,10 +853,11 @@ function definemask(loop::Loop, W::Symbol, allon::Bool)
863
853
maskexpr (W, loop. rangesym, allon)
864
854
end
865
855
end
866
- function setup_Wmask! (ls:: LoopSet , W:: Symbol , typeT:: Symbol , vectorized:: Symbol , unrolled:: Symbol , U:: Int )
856
+ function setup_Wmask! (ls:: LoopSet , W:: Symbol , typeT:: Symbol , vectorized:: Symbol , unrolled:: Symbol , tiled :: Symbol , U:: Int )
867
857
pushpreamble! (ls, Expr (:(= ), typeT, determine_eltype (ls)))
868
858
pushpreamble! (ls, Expr (:(= ), W, determine_width (ls, typeT, unrolled)))
869
859
pushpreamble! (ls, definemask (ls. loops[vectorized], W, U > 1 && unrolled === vectorized))
860
+ # define_remaining_ops!( ls, vectorized, W, unrolled, tiled, U )
870
861
end
871
862
function lower_tiled (ls:: LoopSet , vectorized:: Symbol , U:: Int , T:: Int )
872
863
order = ls. loop_order. loopnames
@@ -875,11 +866,11 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
875
866
mangledtiled = tiledsym (tiled)
876
867
W = gensym (:W )
877
868
typeT = gensym (:T )
878
- setup_Wmask! (ls, W, typeT, vectorized, unrolled, U)
869
+ setup_Wmask! (ls, W, typeT, vectorized, unrolled, tiled, U)
879
870
tiledloop = ls. loops[tiled]
880
871
static_tile = tiledloop. hintexact
881
872
unrolledloop = ls. loops[unrolled]
882
- initialize_outer_reductions! (ls, 0 , 4 , W, typeT, unrolled)
873
+ initialize_outer_reductions! (ls, 0 , 4 , W, typeT, vectorized) # unrolled)
883
874
q = Expr (:block , Expr (:(= ), mangledtiled, 0 ))
884
875
# we build up the loop expression.
885
876
Trem = Tt = T
@@ -932,12 +923,11 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
932
923
end
933
924
function lower_unrolled (ls:: LoopSet , vectorized:: Symbol , U:: Int )
934
925
order = ls. loop_order. loopnames
935
- # @show order
936
926
unrolled = last (order)
937
927
# W = VectorizationBase.pick_vector_width(ls, unrolled)
938
928
W = gensym (:W )
939
929
typeT = gensym (:T )
940
- setup_Wmask! (ls, W, typeT, vectorized, unrolled, U)
930
+ setup_Wmask! (ls, W, typeT, vectorized, unrolled, last (order), U)
941
931
q = lower_unrolled! (Expr (:block , Expr (:(= ), unrolled, 0 )), ls, vectorized, U, - 1 , W, typeT, ls. loops[unrolled])
942
932
Expr (:block , ls. preamble, q)
943
933
end
950
940
# Requires sorting
951
941
function lower (ls:: LoopSet )
952
942
order, vectorized, U, T = choose_order (ls)
953
- # @show order, U, T
954
- # @show ls.loop_order.loopnames
955
943
istiled = T != - 1
956
944
fillorder! (ls, order, istiled)
957
- # @show order, ls.loop_order.loopnames
958
945
istiled ? lower_tiled (ls, vectorized, U, T) : lower_unrolled (ls, vectorized, U)
959
946
end
960
947
0 commit comments