Skip to content

Commit ce797c6

Browse files
committed
Minor additional progress.
1 parent 3b18781 commit ce797c6

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

src/condense_loopset.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,22 @@ function argmeta_and_costs_description(ls::LoopSet, arraysymbolinds)
126126
)
127127
end
128128

129+
function loopset_return_value(ls::LoopSet)
130+
if length(ls.outer_reductions) == 1
131+
Expr(:call, :extract_data, Symbol(mangledvar(operations(ls)[ls.outer_reductions[1]]), 0))
132+
elseif length(ls.outer_reductions) > 1
133+
ret = Expr(:tuple)
134+
ops = operations(ls)
135+
for or ls.outer_reductions
136+
push!(ret.args, Expr(:call, :extract_data, Symbol(mangledvar(ops[or]), 0)))
137+
end
138+
ret
139+
else
140+
nothing
141+
end
142+
end
143+
144+
129145
# Try to condense in type stable manner
130146
function generate_call(ls::LoopSet)
131147
operation_descriptions = Expr(:curly, :Tuple)
@@ -144,5 +160,18 @@ function generate_call(ls::LoopSet)
144160
q
145161
end
146162

163+
function setup_call(ls::LoopSet)
164+
call = generate_call(ls)
165+
retv = loopset_return_value(ls)
166+
q = Expr(:block,gc_preserve(ls, Expr(:(=), retv, call)))
167+
for or ls.outer_reductions
168+
op = ls.operations[or]
169+
var = name(op)
170+
mvar = mangledvar(op)
171+
instr = instruction(op)
172+
push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
173+
end
174+
175+
end
147176

148177

src/lowering.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,11 +701,11 @@ end
701701
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
702702
for or ls.outer_reductions
703703
op = ls.operations[or]
704-
var = op.variable
704+
var = name(op)
705705
mvar = mangledvar(op)
706-
instr = op.instruction
706+
instr = instruction(op)
707707
reduce_expr!(q, mvar, instr, U)
708-
push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
708+
length(ls.opdict) == 0 || push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
709709
end
710710
end
711711
function gc_preserve(ls::LoopSet, q::Expr)

src/reconstruct_loopset.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ function avx_body(ops, arf, AM, LB, vargs)
177177
opsymbols = [gensym(:op) for _ eachindex(ops)]
178178
mrefs = create_mrefs(ls, arf, arraysymbolinds, opsymbols, vargs)
179179
add_ops!(ls, ops, mrefs, opsymbols, elementbytes)
180+
q = lower(ls)
181+
push!(q.args, loopset_return_value(ls))
182+
q
180183
end
181184

182185

@@ -185,6 +188,6 @@ end
185188
OperationStruct[OPS.parameters...],
186189
ArrayRefStruct[ARF.parameters...],
187190
AM.parameters, LB.parameters, vargs
188-
)
191+
)
189192
end
190193

0 commit comments

Comments
 (0)