@@ -363,19 +363,19 @@ function getu₁forreduct(ls::LoopSet, op::Operation, u₁::Int)
363
363
end
364
364
"""
    isidentityop(op::Operation)

Return `true` when `op` is a compute operation whose instruction is `:identity`
and which has exactly one parent; return `false` otherwise.
"""
function isidentityop(op::Operation)
    # Checks run in the same order as a chained `&&`, stopping at the first failure.
    iscompute(op) || return false
    instruction(op).instr === :identity || return false
    return length(parents(op)) == 1
end
365
365
"""
    reduce_parent!(q::Expr, ls::LoopSet, op::Operation, opp::Operation, parent::Symbol)

Return the symbol that should be used in place of `parent`.

`parent` is returned unchanged, and `q` is left untouched, when any of these hold:
- `op` is vectorized,
- `dependent_outer_reducts(ls, op)` is true,
- `opp` is not vectorized and is not an identity op,
- `opp` is an identity op whose single parent is not vectorized.

Otherwise an assignment `newsym = f(parent)` is pushed onto `q.args`, where `newsym`
is a fresh `gensym(parent)` and `f` is `lv(reduction_to_scalar(instr))`, with `instr`
taken from `opp` (if it is vectorized) or from the sole parent of `opp`. The fresh
symbol is returned.
"""
function reduce_parent!(
    q::Expr, ls::LoopSet, op::Operation, opp::Operation, parent::Symbol
)
    # Nothing to do when `op` is itself vectorized or the outer-reduct check fires.
    (isvectorized(op) || dependent_outer_reducts(ls, op)) && return parent

    # Pick the operation whose instruction determines the reduction to apply.
    if isvectorized(opp)
        source = opp
    else
        isidentityop(opp) || return parent
        # Look through the identity op to its only parent.
        source = only(parents(opp))
        isvectorized(source) || return parent
    end

    reduced = gensym(parent)
    reducecall = Expr(:call, lv(reduction_to_scalar(source.instruction)), parent)
    push!(q.args, Expr(:(=), reduced, reducecall))
    return reduced
end
380
380
function lower_compute! (
381
381
q:: Expr , op:: Operation , ls:: LoopSet , ua:: UnrollArgs , mask:: Bool
@@ -410,6 +410,9 @@ function lower_compute!(
410
410
parentop. identifier, gensym (parentop. variable), parentop. elementbytes, parentop. instruction, parentop. node_type,
411
411
parentop. dependencies, parentop. reduced_deps, parentop. parents, parentop. ref, parentop. reduced_children
412
412
)
413
+ newparentop. vectorized = false
414
+ newparentop. u₁unrolled = false
415
+ newparentop. u₂unrolled = parents_u₂syms[i]
413
416
parentname = mangledvar (parentop)
414
417
newparentname = mangledvar (newparentop)
415
418
parents_op[i] = newparentop
0 commit comments