Skip to content

Commit 21a9872

Browse files
committed
When unrolling outer reductions, currently don't switch unrolled axis contextually. May be nice to add that back in codegen eventually, but for now it's safer to disable. Also handle condvar & loopmasks correctly when the condvar isn't unrolled.
1 parent 6500968 commit 21a9872

File tree

8 files changed

+100
-62
lines changed

8 files changed

+100
-62
lines changed

src/codegen/lower_compute.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ end
4141

4242
struct FalseCollection end
4343
Base.getindex(::FalseCollection, i...) = false
44-
function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol)
45-
# map(opp -> isunrolled_sym(opp, u₁loop), parents(op)), map(opp -> isunrolled_sym(opp, u₂loop), parents(op))
44+
function parent_unroll_status(op::Operation, u₁loop::Symbol)
4645
map(opp -> isunrolled_sym(opp, u₁loop), parents(op)), fill(false, length(parents(op)))
4746
end
48-
function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol, u₂max::Int)
49-
u₂max 0 || return parent_unroll_status(op, u₁loop, u₂loop)
47+
function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂max::Int)
48+
# u₂max ≥ 0 || return parent_unroll_status(op, u₁loop)
49+
u₂max == -1 && return parent_unroll_status(op, u₁loop)
5050
vparents = parents(op);
5151
# parent_names = Vector{Symbol}(undef, length(vparents))
5252
parents_u₁syms = Vector{Bool}(undef, length(vparents))
5353
parents_u₂syms = Vector{Bool}(undef, length(vparents))
5454
for i eachindex(vparents)
55-
parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym(vparents[i], u₁loop, u₂loop, u₂max)
55+
parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym(vparents[i], u₁loop, u₂loop, vloop)#, u₂max)
5656
end
5757
# parent_names, parents_u₁syms, parents_u₂syms
5858
parents_u₁syms, parents_u₂syms
@@ -317,11 +317,14 @@ function lower_compute!(
317317
instr = instruction(op)
318318
parents_op = parents(op)
319319
nparents = length(parents_op)
320-
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, u₂max, suffix)
321-
opunrolled = u₁unrolledsym || u₁loopsym loopdependencies(op)
320+
# __u₂max = ls.unrollspecification[].u₂
321+
# TODO: perhaps allos for swithcing unrolled axis again
322+
# mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
323+
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
324+
opunrolled = u₁unrolledsym || isu₁unrolled(op)
322325
# parent_names, parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, suffix)
323-
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loopsym, u₂loopsym, u₂max)
324-
tiledouterreduction = if suffix == -1
326+
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loopsym, u₂loopsym, vloopsym, u₂max)
327+
tiledouterreduction = if (suffix == -1)# || (vloopsym === u₂loopsym)
325328
suffix_ = Symbol("")
326329
-1
327330
else
@@ -401,8 +404,15 @@ function lower_compute!(
401404
varsym = if tiledouterreduction > 0 # then suffix ≠ -1
402405
# modsuffix = ((u + suffix*(Uiter + 1)) & 7)
403406
isouterreduct = true
404-
modsuffix = suffix % tiled_outerreduct_unroll(ls)
405-
Symbol(mangledvar(op), modsuffix)
407+
# if u₁unrolledsym
408+
# modsuffix = ls.unrollspecification[].u₁#getu₁full(ls, u₁)#u₁
409+
# Symbol(mangledvar(op), '_', modsuffix)
410+
# else
411+
modsuffix = suffix % tiled_outerreduct_unroll(ls)
412+
Symbol(mangledvar(op), modsuffix)
413+
# end
414+
# dopartialmap = u₁ > 1
415+
406416
# Symbol(mvar, modsuffix)
407417
# elseif u₁unrolledsym
408418
# Symbol(mvar, u)
@@ -419,7 +429,6 @@ function lower_compute!(
419429
else
420430
Symbol(mvar, '_', 1)
421431
end
422-
# @show getu₁forreduct(ls, op, u₁)
423432
selfopname = varsym
424433
selfdep = 0
425434
for n 1:nparents
@@ -435,7 +444,6 @@ function lower_compute!(
435444
selfopname = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
436445
push!(instrcall.args, selfopname)
437446
else
438-
# @show varsym
439447
push!(instrcall.args, varsym)
440448
end
441449
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (isouterreduction(ls, opp) != -1)
@@ -447,6 +455,11 @@ function lower_compute!(
447455
end
448456
end
449457
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
458+
# if selfdep ≠ 0
459+
# @show mvar
460+
# # @show isu₁unrolled(op), u₁unrolledsym, u₁, u₂max
461+
# # @show selfdep, selfdepreduce#, op
462+
# end
450463
# push!(q.args, (isreduct, u₁, (!u₁unrolledsym), isu₁unrolled(op), dopartialmap, varsym))
451464
if maskreduct
452465
ifelsefunc = if ls.unrollspecification[].u₁ == 1
@@ -486,7 +499,7 @@ function lower_compute!(
486499
end
487500
elseif selfdep != 0 && (dopartialmap ||
488501
(isouterreduct && (opunrolled) && (u₁ < ls.unrollspecification[].u₁)) ||
489-
(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op)))
502+
(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op))) # TODO: DRY `selfdepreduce` definition
490503
# first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification[].u₁)`):
491504
# checks if we're in the "reduct" part of an outer reduction
492505
#

src/codegen/lower_constant.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ function lower_zero!(
4747
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, zerotyp::NumberType = zerotype(ls, op)
4848
)
4949
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
50-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, u₂max, suffix)
50+
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
51+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
5152
!opu₂ && suffix > 0 && return
5253
# TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled
5354
# reductions need to consider reduct-status
@@ -94,7 +95,8 @@ function lower_constant!(
9495
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs
9596
)
9697
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
97-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, u₂max, suffix)
98+
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
99+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
98100
!opu₂ && suffix > 0 && return
99101
mvar = Symbol(mvar, '_', Core.ifelse(opu₁, u₁, 1))
100102
instruction = op.instruction

src/codegen/lower_memory_common.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ end
77
function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
88
id = parentind(ind, op)
99
id == -1 && return ind, op
10-
@unpack u₁, u₁loopsym, u₂loopsym, u₂max, suffix = td
10+
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
1111
parent = parents(op)[id]
12-
pvar, u₁op, u₂op = variable_name_and_unrolled(parent, u₁loopsym, u₂loopsym, u₂max, suffix)
12+
pvar, u₁op, u₂op = variable_name_and_unrolled(parent, u₁loopsym, u₂loopsym, vloopsym, suffix)
1313
Symbol(pvar, '_', Core.ifelse(u₁op, u₁, 1)), parent
1414
end
1515

@@ -57,7 +57,7 @@ function _addoffset!(ret::Expr, vloopstride, indexstride, index, offset, calcbyp
5757
addoffset!(ret, gethint(vloopstride)*gethint(indexstride), index, offset, calcbypointeroffset)
5858
else
5959
addoffset!(ret, mulexpr(vloopstride,indexstride), index, offset, calcbypointeroffset)
60-
end
60+
end
6161
end
6262
# multiply `index` by `indexstride`
6363
function addoffset!(ret::Expr, vloopstride, indexstride, index, offset, calcbypointeroffset::Bool) # 6 -> (5 or 6) args
@@ -163,7 +163,6 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
163163
end
164164
ret
165165
end
166-
isconditionalmemop(op::Operation) = (instruction(op).instr === :conditionalload) || (instruction(op).instr === :conditionalstore!)
167166
# function unrolled_curly(op::Operation, u₁::Int, u₁loopsym::Symbol, vectorized::Symbol, mask::Bool)
168167

169168
# interleave: `0` means `false`, positive means literal, negative means multiplier
@@ -328,34 +327,56 @@ end
328327
@generated function and_last(v::VecUnroll{N}, m) where {N}
329328
q = Expr(:block, Expr(:meta,:inline), :(vd = data(v)))
330329
t = Expr(:call, lv(:promote))
330+
gf = GlobalRef(Core, :getfield)
331331
for n 1:N
332-
push!(t.args, :(getfield(vd, $n, false)))
332+
push!(t.args, :($gf(vd, $n, false)))
333333
end
334-
push!(t.args, :(getfield(vd, $(N+1), false) & m))
334+
push!(t.args, :($gf(vd, $(N+1), false) & m))
335335
push!(q.args, Expr(:call, lv(:VecUnroll), t))
336336
q
337337
end
338338

339+
340+
isconditionalmemop(op::Operation) = (instruction(op).instr === :conditionalload) || (instruction(op).instr === :conditionalstore!)
339341
function add_memory_mask!(memopexpr::Expr, op::Operation, td::UnrollArgs, mask::Bool)
340342
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
341343
if isconditionalmemop(op)
342344
condop = last(parents(op))
343345
opu₂ = (suffix -1) && isu₂unrolled(op)
344-
condvar, condu₁unrolled = condvarname_and_unroll(condop, u₁loopsym, u₂loopsym, u₂max, suffix, opu₂)
346+
condvar, condu₁unrolled = condvarname_and_unroll(condop, u₁loopsym, u₂loopsym, vloopsym, suffix, opu₂)
345347
# if it isn't unrolled, then `m`
346348
u = condu₁unrolled ? u₁ : 1
347349
# u = isu₁unrolled(condop) ? u₁ : 1
348350
condvar = Symbol(condvar, '_', u)
349-
# @show condvar
351+
# If we need to apply `MASKSYMBOL` and the condvar
352+
# 2 condvar possibilities:
353+
# `VecUnroll` applied everywhere
354+
# single mask "broadcast"
355+
# 2 mask possibilities
356+
# u₁loopsym ≠ vloopsym, and we mask all
357+
# u₁loopsym == vloopsym, and we mask last
358+
# broadcast both, so can do so implicitly
359+
# this is true whether or not `condbroadcast`
350360
if !mask || (!isvectorized(op))
351361
push!(memopexpr.args, condvar)
352-
else
353-
# we only want to apply mask to `u₁`
362+
elseif (u₁loopsym vloopsym) | (u₁ == 1) # mask all equivalenetly
363+
push!(memopexpr.args, Expr(:call, lv(:&), condvar, MASKSYMBOL))
364+
# if the condition `(u₁loopsym ≢ vloopsym) | (u₁ == 1)` failed, we need to apply `MASKSYMBOL` only to last unroll.
365+
elseif !condu₁unrolled && isu₁unrolled(op) # condbroadcast
366+
# explicitly broadcast `condvar`, and apply `MASKSYMBOL` to end
367+
t = Expr(:call, lv(:promote))
368+
for um 1:u₁-1
369+
push!(t.args, condvar)
370+
end
371+
push!(t.args, Expr(:call, lv(:&), condvar, MASKSYMBOL))
372+
push!(memopexpr.args, Expr(:call, lv(:VecUnroll), t))
373+
else# !condbroadcast && !vecunrolled
354374
push!(memopexpr.args, Expr(:call, lv(:and_last), condvar, MASKSYMBOL))
355375
end
356376
elseif mask && isvectorized(op)
357377
push!(memopexpr.args, MASKSYMBOL)
358378
end
379+
nothing
359380
end
360381

361382
varassignname(var::Symbol, u::Int, isunrolled::Bool) = isunrolled ? Symbol(var, u) : var
@@ -372,7 +393,7 @@ function name_memoffset(var::Symbol, op::Operation, td::UnrollArgs, u₁unrolled
372393
name, mo
373394
end
374395

375-
function condvarname_and_unroll(cond, u₁loop, u₂loop, u₂max, suffix, opu₂)
376-
condvar, condu₁, condu₂ = variable_name_and_unrolled(cond, u₁loop, u₂loop, u₂max, Core.ifelse(opu₂, suffix, -1))
396+
function condvarname_and_unroll(cond::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, suffix::Int, opu₂::Bool)
397+
condvar, condu₁, condu₂ = variable_name_and_unrolled(cond, u₁loop, u₂loop, vloop, Core.ifelse(opu₂, suffix, -1))
377398
condvar, condu₁
378399
end

src/codegen/lower_store.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ function reduce_expr_u₂(toreduct::Symbol, instr::Instruction, u₂::Int)
2727
Expr(:call, lv(:reduce_tup), reduce_to_onevecunroll(instr), t)
2828
end
2929
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int, isu₁unrolled::Bool, isu₂unrolled::Bool)
30-
# if u₂ == -1
31-
# u₁u, u₂u = (true, false)
32-
# else
33-
# u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, _Umax)
34-
# end
3530
if isu₂unrolled# u₂ != -1
3631
_toreduct = Symbol(toreduct, 0)
3732
push!(q.args, Expr(:(=), _toreduct, reduce_expr_u₂(toreduct, instr, u₂)))
@@ -59,7 +54,7 @@ function lower_store_collection!(
5954

6055
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, vloop, u₂max, suffix = ua
6156
ops = operations(ls)
62-
57+
# __u₂max = ls.unrollspecification[].u₂
6358
nouter = length(idsformap)
6459

6560
t = Expr(:tuple)
@@ -68,7 +63,7 @@ function lower_store_collection!(
6863
for (i,(opid,_)) enumerate(idsformap)
6964
opp = first(parents(ops[opidmap[opid]]))
7065

71-
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, u₂max)
66+
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, vloopsym)#, __u₂max)
7267
u = Core.ifelse(isu₁, u₁, 1)
7368
mvar = Symbol(variable_name(opp, ifelse(isu₂, suffix, -1)), '_', u)
7469
# mvar = Symbol(variable_name(_op, suffix), '_', u)
@@ -122,7 +117,8 @@ function lower_store!(
122117
if (opp.instruction.instr === reductfunc) && isone(length(parents(opp)))
123118
opp = only(parents(opp))
124119
end
125-
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, u₂max)
120+
# __u₂max = ls.unrollspecification[].u₂
121+
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, vloopsym)#, __u₂max)
126122
u = isu₁ ? u₁ : 1
127123
mvar = Symbol(variable_name(opp, ifelse(isu₂, suffix, -1)), '_', u)
128124
if all(op.ref.loopedindex)
@@ -208,7 +204,7 @@ function lower_tiled_store!(blockq::Expr, op::Operation, ls::LoopSet, ua::Unroll
208204
throw("Operation $opp's instruction is $reductfunc, shouldn't be able to reach here.")
209205
# opp = only(parents(opp))
210206
end
211-
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, u₂)
207+
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, vloopsym)#, u₂)
212208
@assert isu₂
213209
# It's reasonable forthis to be `!isu₁`
214210
u = Core.ifelse(isu₁, u₁, 1)

src/codegen/lowering.jl

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
23
# the `lowernonstore` and `lowerstore` options are there as a means of lowering all non-store operations before lowering the stores.
34
function lower!(
45
q::Expr, ops::AbstractVector{Operation}, ls::LoopSet, unrollsyms::UnrollSymbols, u₁::Int, u₂::Int,
@@ -443,7 +444,7 @@ function initialize_outer_reductions!(
443444
isvectorized = vectorized reduceddependencies(op)
444445
typeTr = ELTYPESYMBOL
445446
z = if isvectorized
446-
if Umax == 1
447+
if Umax == 1 || u₂ -1
447448
if reduct_zero === :zero
448449
Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, typeTr, rs)
449450
else
@@ -460,14 +461,14 @@ function initialize_outer_reductions!(
460461
Expr(:call, reduct_zero, typeTr)
461462
end
462463
mvar = variable_name(op, -1)
463-
# u1u, u2u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, u₂loop, u₂max)
464464
if u₂ == -1
465465
push!(q.args, Expr(:(=), Symbol(mvar, '_', _Umax), z))
466-
else
467-
u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, u₂)
466+
else#if isu₂unrolled(op) #& (us.vloopnum ≠ us.u₂loopnum) # tiled outer reduction, u₂unrolled
467+
# TODO: add `(us.vloopnum ≠ us.u₂loopnum)` check to avoid unrolling and vectorizing a reduction along the same axis
468+
u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, getloop(ls, us.vloopnum).itersymbol)#, u₂)
468469
if u₁u
469-
push!(q.args, Expr(:(=), Symbol(mvar, '_', _Umax), z))
470-
else
470+
push!(q.args, Expr(:(=), Symbol(mvar, '_', u₁), z))
471+
else
471472
for u 0:_Umax-1
472473
# push!(q.args, Expr(:(=), Symbol(mvar, '_', u), z))
473474
push!(q.args, Expr(:(=), Symbol(mvar, u), z))
@@ -536,21 +537,22 @@ end
536537
## This performs reduction to one `Vec`
537538
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
538539
us = ls.unrollspecification[]
539-
u₁f, u₂f = if us.u₂ == -1 # TODO: these multiple meanings make code hard to follow. Simplify.
540+
u₁f, u₂f = if us.u₂ == -1
540541
ifelse(U == -1, us.u₁, U), -1
541542
else
542543
us.u₁, U
543544
end
544545
# u₁loop, u₂loop = getunrolled(ls)
545546
u₁loop = getloop(ls, us.u₁loopnum).itersymbol
546547
u₂loop = getloop(ls, us.u₂loopnum).itersymbol
548+
vloop = getloop(ls, us.vloopnum).itersymbol
547549
for or ls.outer_reductions
548550
op = ls.operations[or]
549551
var = name(op)
550552
mvar = mangledvar(op)
551553
instr = instruction(op)
552-
u₁u, u₂u = isunrolled_sym(op, u₁loop, u₂loop, u₂f)
553-
reduce_expr!(q, mvar, instr, u₁f, u₂f, u₁u, u₂u)#isu₁unrolled(op))
554+
u₁u, u₂u = isunrolled_sym(op, u₁loop, u₂loop, vloop)#, u₂f)
555+
reduce_expr!(q, mvar, instr, u₁f, u₂f, u₁u, u₂u)
554556
if !iszero(length(ls.opdict))
555557
if (isu₁unrolled(op) | isu₂unrolled(op))
556558
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Symbol(mvar, "##onevec##"), var)))
@@ -838,7 +840,7 @@ It returns `true`/`false` for each loop, indicating whether they're unrolled.
838840
If there is a third argument, it will avoid unrolling that symbol along reductions if said symbol is part of the reduction chain.
839841
840842
"""
841-
function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol)
843+
function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol)
842844
u₁ild = isu₁unrolled(op)
843845
u₂ild = isu₂unrolled(op)
844846
(accesses_memory(op) | isloopvalue(op)) && return (u₁ild, u₂ild)
@@ -859,8 +861,16 @@ function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol)
859861
u₂reduced = u₂loop reductops
860862
# If they're being reduced, we want to only unroll the reduced variable along one of the two loops.
861863
# @show u₁reduced, u₂reduced
862-
if u₂reduced # if both are reduced, we unroll u₁
863-
true, false
864+
if u₂reduced
865+
if u₁reduced# if both are reduced, we unroll u₁
866+
if vloop === u₁loop
867+
false,true
868+
else
869+
true, false
870+
end
871+
else
872+
true,false
873+
end
864874
elseif u₁reduced
865875
false, true
866876
# true, false
@@ -873,24 +883,19 @@ function isunrolled_sym(op::Operation, u₁loop::Symbol)
873883
isu₁unrolled(op) || (isconstant(op) & (u₁loop reducedchildren(op)))
874884
end
875885

876-
# isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol) = (isunrolled_sym(op, u₁loop), false)
877-
# isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol, ::Int) = isunrolled_sym(op, u₁loop, u₂loop)
878-
function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol, u₂max::Int)
879-
((u₂max > 1) | accesses_memory(op)) ? isunrolled_sym(op, u₁loop, u₂loop) : (isunrolled_sym(op, u₁loop), false)
886+
function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂max::Int)
887+
((u₂max > 1) | accesses_memory(op)) ? isunrolled_sym(op, u₁loop, u₂loop, vloop) : (isunrolled_sym(op, u₁loop), false)
880888
end
881889

882890
function variable_name(op::Operation, suffix::Int)
883891
mvar = mangledvar(op)
884892
suffix == -1 ? mvar : Symbol(mvar, suffix, :_)
885893
end
886894

887-
function variable_name_and_unrolled(op::Operation, u₁loop::Symbol, u₂loop::Symbol, u₂max::Int, u₂iter::Int)
888-
# we require
889-
if (u₂iter == -1) | ((u₂max 1) & (!accesses_memory(op)))
890-
return mangledvar(op), isunrolled_sym(op, u₁loop), false
891-
end
892-
u₁op, u₂op = isunrolled_sym(op, u₁loop, u₂loop)
895+
# function variable_name_and_unrolled(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂max::Int, u₂iter::Int)
896+
function variable_name_and_unrolled(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂iter::Int)
897+
# u₁op, u₂op = isunrolled_sym(op, u₁loop, u₂loop, vloop, Core.ifelse(u₂iter == -1, 1, u₂max))
898+
u₁op, u₂op = isunrolled_sym(op, u₁loop, u₂loop, vloop)#, u₂max)
893899
mvar = u₂op ? variable_name(op, u₂iter) : mangledvar(op)
894-
# mvar = mangledvar(op)
895900
mvar, u₁op, u₂op
896901
end

0 commit comments

Comments
 (0)