1
1
2
2
3
- function load_constrained (op, u₁loop, u₂loop, innermost_loop_or_vloop, forprefetch = false )
3
+ function load_constrained (op:: Operation , u₁loop:: Symbol , u₂loop:: Symbol , innermost_loop_or_vloop:: Symbol , forprefetch:: Bool = false )
4
4
dependsonu₁ = isu₁unrolled (op)
5
5
dependsonu₂ = isu₂unrolled (op)
6
6
if forprefetch
@@ -21,8 +21,8 @@ function load_constrained(op, u₁loop, u₂loop, innermost_loop_or_vloop, forpr
21
21
isload (opp) && all (in (loopdependencies (opp)), unrolleddeps)
22
22
end
23
23
end
24
- function check_if_remfirst (ls, ua)
25
- usorig = ls. unrollspecification[]
24
+ function check_if_remfirst (ls:: LoopSet , ua:: UnrollArgs )
25
+ usorig = ls. unrollspecification
26
26
@unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
27
27
u₁loop = getloop (ls, u₁loopsym)
28
28
u₂loop = getloop (ls, u₂loopsym)
@@ -39,20 +39,23 @@ function sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)
39
39
! (load_constrained (op, u₁loopsym, u₂loopsym, vloopsym) || check_if_remfirst (ls, ua))
40
40
end
41
41
42
- struct FalseCollection end
43
- Base. getindex (:: FalseCollection , i... ) = false
44
- function parent_unroll_status (op:: Operation , u₁loop:: Symbol )
45
- map (opp -> isunrolled_sym (opp, u₁loop), parents (op)), fill (false , length (parents (op)))
42
"""
    parent_unroll_status(op::Operation, u₁loop::Symbol, us::UnrollSpecification)

Return a pair of `Bool` vectors with one entry per parent of `op`.

The first vector stores `isunrolled_sym(parent, u₁loop, us)` for each parent;
the second vector is filled with `false` for every parent.
"""
function parent_unroll_status(op::Operation, u₁loop::Symbol, us::UnrollSpecification)
    vparents = parents(op)
    nparents = length(vparents)
    # This method never marks a parent in the second vector.
    parents_u₂syms = fill(false, nparents)
    parents_u₁syms = Vector{Bool}(undef, nparents)
    for i ∈ eachindex(vparents)
        parents_u₁syms[i] = isunrolled_sym(vparents[i], u₁loop, us)
    end
    return parents_u₁syms, parents_u₂syms
end
47
- function parent_unroll_status (op:: Operation , u₁loop:: Symbol , u₂loop:: Symbol , vloop:: Symbol , u₂max:: Int )
48
- # u₂max ≥ 0 || return parent_unroll_status(op, u₁loop)
49
- u₂max == - 1 && return parent_unroll_status (op, u₁loop)
51
+ function parent_unroll_status (op:: Operation , u₁loop:: Symbol , u₂loop:: Symbol , vloop:: Symbol , u₂max:: Int , us:: UnrollSpecification )
52
+ u₂max == - 1 && return parent_unroll_status (op, u₁loop, us)
50
53
vparents = parents (op);
51
54
# parent_names = Vector{Symbol}(undef, length(vparents))
52
55
parents_u₁syms = Vector {Bool} (undef, length (vparents))
53
56
parents_u₂syms = Vector {Bool} (undef, length (vparents))
54
57
for i ∈ eachindex (vparents)
55
- parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym (vparents[i], u₁loop, u₂loop, vloop)# , u₂max)
58
+ parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym (vparents[i], u₁loop, u₂loop, vloop, us )# , u₂max)
56
59
end
57
60
# parent_names, parents_u₁syms, parents_u₂syms
58
61
parents_u₁syms, parents_u₂syms
@@ -100,7 +103,10 @@ vecunrolllen(::Type{VecUnroll{N,W,T,V}}) where {N,W,T,V} = (N::Int + 1)
100
103
vecunrolllen (_) = - 1
101
104
function ifelselastexpr (hasf:: Bool , M:: Int , vargtypes, K:: Int , S:: Int , maskearly:: Bool )
102
105
q = Expr (:block , Expr (:meta ,:inline ))
103
- vargs = map (k -> Symbol (:varg_ ,k), 1 : K)
106
+ vargs = Vector {Symbol} (undef, K)
107
+ for k ∈ 1 : K
108
+ vargs[k] = Symbol (:varg_ ,k)
109
+ end
104
110
lengths = Vector {Int} (undef, K);
105
111
for k ∈ 1 : K
106
112
lengths[k] = l = vecunrolllen (vargtypes[k])
@@ -295,7 +301,7 @@ function parent_op_name(
295
301
parent, u
296
302
end
297
303
function getuouterreduct (ls:: LoopSet , op:: Operation , suffix)
298
- us = ls. unrollspecification[]
304
+ us = ls. unrollspecification
299
305
if us. vloopnum === us. u₁loopnum # unroll u₂
300
306
suffix
301
307
else # unroll u₁
306
312
function getu₁full (ls:: LoopSet , u₁:: Int )
307
313
Ureduct = ureduct (ls)
308
314
ufull = if Ureduct == - 1 # no reducing
309
- ls. unrollspecification[] . u₁
315
+ ls. unrollspecification. u₁
310
316
else
311
317
Ureduct
312
318
end
@@ -326,9 +332,9 @@ function getu₁forreduct(ls::LoopSet, op::Operation, u₁::Int)
326
332
end
327
333
if isu₁unrolled (op)
328
334
return u₁
329
- elseif (ls. unrollspecification[] . u₂ != - 1 ) && length (ls. outer_reductions) > 0
335
+ elseif (ls. unrollspecification. u₂ != - 1 ) && length (ls. outer_reductions) > 0
330
336
# then `ureduct` doesn't tell us what we need, so
331
- return ls. unrollspecification[] . u₁
337
+ return ls. unrollspecification. u₁
332
338
else # we need to find u₁-full
333
339
return getu₁full (ls, u₁)
334
340
end
@@ -357,13 +363,12 @@ function lower_compute!(
357
363
instr = instruction (op)
358
364
parents_op = parents (op)
359
365
nparents = length (parents_op)
360
- # __u₂max = ls.unrollspecification[] .u₂
366
+ # __u₂max = ls.unrollspecification.u₂
361
367
# TODO : perhaps allow for switching unrolled axis again
362
- # mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
363
- mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled (op, u₁loopsym, u₂loopsym, vloopsym, suffix)
368
+ mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled (op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
364
369
opunrolled = u₁unrolledsym || isu₁unrolled (op)
365
- # parent_names, parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, suffix)
366
- parents_u₁syms, parents_u₂syms = parent_unroll_status (op, u₁loopsym, u₂loopsym, vloopsym, u₂max)
370
+ us = ls . unrollspecification
371
+ parents_u₁syms, parents_u₂syms = parent_unroll_status (op, u₁loopsym, u₂loopsym, vloopsym, u₂max, us )
367
372
# tiledouterreduction = if num_loops(ls) == 1# (suffix == -1)# || (vloopsym === u₂loopsym)
368
373
tiledouterreduction = if (suffix == - 1 )# || (vloopsym === u₂loopsym)
369
374
suffix_ = Symbol (" " )
@@ -411,8 +416,8 @@ function lower_compute!(
411
416
# if isreduct
412
417
# @show u₁unrolledsym, u₂unrolledsym, isu₁unrolled(op), isu₂unrolled(op) op
413
418
# end
414
- if Base. libllvm_version < v " 11.0.0" && (suffix ≠ - 1 ) && isreduct# && (iszero(suffix) || (ls.unrollspecification[] .u₂ - 1 == suffix))
415
- # if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification[] .u₂ - 1 == suffix))
419
+ if Base. libllvm_version < v " 11.0.0" && (suffix ≠ - 1 ) && isreduct# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
420
+ # if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
416
421
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
417
422
instrfid = findfirst (Base. Fix2 (=== ,instr. instr), (:vfmadd_fast , :vfnmadd_fast , :vfmsub_fast , :vfnmsub_fast ))
418
423
# instrfid = findfirst(isequal(instr.instr), (:vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
@@ -449,13 +454,13 @@ function lower_compute!(
449
454
# modsuffix = ((u + suffix*(Uiter + 1)) & 7)
450
455
isouterreduct = true
451
456
# if u₁unrolledsym
452
- # modsuffix = ls.unrollspecification[] .u₁#getu₁full(ls, u₁)#u₁
457
+ # modsuffix = ls.unrollspecification.u₁#getu₁full(ls, u₁)#u₁
453
458
# Symbol(mangledvar(op), '_', modsuffix)
454
459
# else
455
460
if u₁unrolledsym
456
461
modsuffix = 0
457
462
else
458
- modsuffix = suffix % ls. ureduct[]
463
+ modsuffix = suffix % ls. ureduct
459
464
end
460
465
Symbol (mangledvar (op), modsuffix)
461
466
# end
@@ -538,7 +543,7 @@ function lower_compute!(
538
543
# end
539
544
# push!(q.args, (isreduct, u₁, (!u₁unrolledsym), isu₁unrolled(op), dopartialmap, varsym))
540
545
if maskreduct
541
- ifelsefunc = if ls . unrollspecification[] . u₁ == 1
546
+ ifelsefunc = if us . u₁ == 1
542
547
:ifelse # don't need to be fancy
543
548
elseif (u₁loopsym != = vloopsym)
544
549
:ifelsepartial # ifelse all the early ones
@@ -574,9 +579,9 @@ function lower_compute!(
574
579
make_partial_map! (instrcall, selfopname, u₁, selfdepreduce)
575
580
end
576
581
elseif selfdep != 0 && (dopartialmap ||
577
- (isouterreduct && (opunrolled) && (u₁ < ls . unrollspecification[] . u₁)) ||
582
+ (isouterreduct && (opunrolled) && (u₁ < us . u₁)) ||
578
583
(isreduct & (u₁ > 1 ) & (! u₁unrolledsym) & isu₁unrolled (op))) # TODO : DRY `selfdepreduce` definition
579
- # first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification[] .u₁)`):
584
+ # first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification.u₁)`):
580
585
# checks if we're in the "reduct" part of an outer reduction
581
586
#
582
587
# second possibility (`(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op))`):
0 commit comments