Skip to content

Commit cd1c612

Browse files
committed
A few more bug fixes
1 parent 7d44eac commit cd1c612

File tree

10 files changed

+98
-78
lines changed

10 files changed

+98
-78
lines changed

src/codegen/loopstartstopmanager.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -324,30 +324,27 @@ function offsetindex(dim::Int, ind::Int, scale::Int, isvectorized::Bool, incr::M
324324
push!(index.args, Expr(:call, lv(:Zero)))
325325
continue
326326
end
327-
if isvectorized
328-
if isone(scale)
329-
pushmulexpr!(index, VECTORWIDTHSYMBOL, incr)
330-
else
331-
push!(index.args, mulexpr(VECTORWIDTHSYMBOL, staticexpr(scale), incr))
332-
end
327+
if !isvectorized
328+
pushmulexpr!(index, scale, incr)
329+
elseif isone(scale)
330+
pushmulexpr!(index, VECTORWIDTHSYMBOL, incr)
333331
else
334-
pushmulexpr!(index, staticexpr(scale), incr)
332+
push!(index.args, mulexpr(VECTORWIDTHSYMBOL, scale, incr))
335333
end
336334
end
337335
index
338336
end
339337
function append_pointer_maxes!(
340338
loopstart::Expr, ls::LoopSet, ar::ArrayReferenceMeta, n::Int, submax::Int, isvectorized::Bool, stopindicator, incr::MaybeKnown
341339
)
340+
vptr_ar = vptr(ar)
342341
if submax < 2
343342
for sub ∈ 0:submax
344-
push!(loopstart.args, Expr(:(=), maxsym(vptr(ar), sub), pointermax(ls, ar, n, sub, isvectorized, stopindicator, incr)))
345-
# push!(loopstart.args, defpointermax(ls, ptrdefs[termind], n, sub, isvectorized, stopindicator))
343+
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), pointermax(ls, ar, n, sub, isvectorized, stopindicator, incr)))
346344
end
347345
else
348346
# @show n, getloop(ls, n) ar
349347
index, ind = pointermax_index(ls, ar, n, submax, isvectorized, stopindicator, incr)
350-
vptr_ar = vptr(ar)
351348
pointercompbase = maxsym(vptr_ar, submax)
352349
push!(loopstart.args, Expr(:(=), pointercompbase, Expr(:call, lv(:gesp), vptr_ar, index)))
353350
dim = length(getindicesonly(ar))
@@ -378,7 +375,7 @@ function append_pointer_maxes!(loopstart::Expr, ls::LoopSet, ar::ArrayReferenceM
378375
stop = last(loop)
379376
incr = step(loop)
380377
if isknown(start) & isknown(stop)
381-
return append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, startstopΔ(loop), incr)
378+
return append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, startstopΔ(loop)+1, incr)
382379
end
383380
looplensym = isone(start) ? getsym(stop) : loop.lensym
384381
append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, looplensym, incr)

src/codegen/lower_compute.jl

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ end
214214
N = maximum(lengths)
215215
Dlen = vecunrolllen(D)
216216
Sreduced = (S > 0) && (lengths[S] == -1) && N != -1
217-
# @show N, M, Sreduced
218217
if Sreduced
219218
M = N
220219
t = q
@@ -262,7 +261,6 @@ function parent_op_name(
262261
if n == tiledouterreduction
263262
parent = Symbol(parent, modsuffix)
264263
else
265-
# parent = variable_name(opp, suffix)
266264
if parents_u₂syms[n]
267265
parent = Symbol(parent, suffix_)
268266
end
@@ -273,16 +271,9 @@ function parent_op_name(
273271
else
274272
getu₁forreduct(ls, opp, u₁)
275273
end
276-
# u = parents_u₁syms[n] ? u₁ : 1
277274
parent = Symbol(parent, '_', u)
278275
end
279-
# if (tiledouterreduction == -1) && LoopVectorization.names(ls)[ls.unrollspecification[].u₁loopnum] ∈ reduceddependencies(opp)
280-
# u = u₁
281-
# else
282-
283-
# end
284276
if opisvectorized && isload(opp) && (!isvectorized(opp))
285-
# @show parents_u₁syms, parents_u₂syms, parent
286277
parent = Symbol(parent, "##broadcasted##")
287278
end
288279
parent
@@ -361,31 +352,24 @@ function lower_compute!(
361352
else
362353
newpname = Symbol(newparentname, '_', u₁)
363354
push!(q.args, Expr(:(=), newpname, Symbol(parentname, '_', u₁)))
364-
# @show newparentop op instruction(newparentop)
365355
reduce_expr!(q, newparentname, instruction(newparentop), u₁, -1, true)
366356
push!(q.args, Expr(:(=), Symbol(newparentname, '_', 1), Symbol(newparentname, "##onevec##")))
367357
end
368358
end
369359
end
370360
# if suffix === nothing# &&
371361
# end
372-
# if instr.instr === :div_fast
373-
# @show op, suffix, parents_u₂syms parents(op)
374-
# @show isu₂unrolled.(parents(op))
375-
# end
376362
# cache unroll and tiling check of parents
377363
# not broadcasted, because we use frequent checks of individual bools
378364
# making BitArrays inefficient.
379365
# parentsyms = [opp.variable for opp ∈ parents(op)]
380366
Uiter = opunrolled ? u₁ - 1 : 0
381-
# @show mvar, opunrolled, u₁, u₁loopsym, u₂loopsym
382367
isreduct = isreduction(op)
383368
if Base.libllvm_version < v"11.0.0" && (suffix ≠ -1) && isreduct# && (iszero(suffix) || (ls.unrollspecification[].u₂ - 1 == suffix))
384369
# if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification[].u₂ - 1 == suffix))
385370
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
386371
instrfid = findfirst(Base.Fix2(===,instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
387372
# instrfid = findfirst(isequal(instr.instr), (:vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
388-
# @show isreduct, instrfid, instr.instr sub_fmas(ls, op, ua)
389373
# want to instcombine when parent load's deps are superset
390374
# also make sure opp is unrolled
391375
if !(instrfid === nothing) && (opunrolled && u₁ > 1) && sub_fmas(ls, op, ua)
@@ -414,6 +398,7 @@ function lower_compute!(
414398
# for u ∈ 0:Uiter
415399
isouterreduct = false
416400
instrcall = callexpr(instr)
401+
dopartialmap = false
417402
varsym = if tiledouterreduction > 0 # then suffix ≠ -1
418403
# modsuffix = ((u + suffix*(Uiter + 1)) & 7)
419404
isouterreduct = true
@@ -426,28 +411,18 @@ function lower_compute!(
426411
if isreduct #(isanouterreduction(ls, op))
427412
# isouterreduct = true
428413
isouterreduct = isanouterreduction(ls, op)
429-
# @show op, isouterreduct, u₁, ls.unrollspecification[].u₂ != -1
430-
if isouterreduct
431-
Symbol(mvar, '_', getu₁full(ls, u₁))
432-
else
433-
Symbol(mvar, '_', getu₁forreduct(ls, op, u₁))
434-
end
414+
u₁reduct = isouterreduct ? getu₁full(ls, u₁) : getu₁forreduct(ls, op, u₁)
415+
dopartialmap = u₁reduct > u₁
416+
Symbol(mvar, '_', u₁reduct)
435417
else
436418
Symbol(mvar, '_', u₁)
437419
end
438420
else
439421
Symbol(mvar, '_', 1)
440422
end
423+
# @show getu₁forreduct(ls, op, u₁)
441424
selfopname = varsym
442-
# @show op, tiledouterreduction, isouterreduct
443-
# if name(op) === Symbol("##op#5631")
444-
# @show name(op), parents(op), name.(parents(op))
445-
# parent_name = parent_op_name(parents_op, 1, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
446-
# @show parent_name
447-
# end
448-
# @show selfopname, varsym, mvar, mangledvar(op)
449425
selfdep = 0
450-
# showexpr = false
451426
for n ∈ 1:nparents
452427
opp = parents_op[n]
453428
if isloopvalue(opp)
@@ -461,19 +436,19 @@ function lower_compute!(
461436
selfopname = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
462437
push!(instrcall.args, selfopname)
463438
else
464-
# @show name(parents_op[n]), name(op), mangledvar(parents_op[n]), mangledvar(op)
439+
# @show varsym
465440
push!(instrcall.args, varsym)
466441
end
467442
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (isouterreduction(ls, opp) != -1)
468443
# this checks if the parent is u₂ unrolled but this operation is not, in which case we need to reduce it.
469444
push!(instrcall.args, reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls)))
470445
else
471446
parent = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
472-
# @show parent, u₁, selfopname
473447
push!(instrcall.args, parent)
474448
end
475449
end
476450
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
451+
# push!(q.args, (isreduct, u₁, (!u₁unrolledsym), isu₁unrolled(op), dopartialmap, varsym))
477452
if maskreduct
478453
ifelsefunc = if ls.unrollspecification[].u₁ == 1
479454
:ifelse # don't need to be fancy
@@ -510,9 +485,9 @@ function lower_compute!(
510485
# @show op, isouterreduct, maskreduct, instr
511486
make_partial_map!(instrcall, selfopname, u₁, selfdepreduce)
512487
end
513-
elseif selfdep != 0 &&
488+
elseif selfdep != 0 && (dopartialmap ||
514489
(isouterreduct && (opunrolled) && (u₁ < ls.unrollspecification[].u₁)) ||
515-
(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op))
490+
(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op)))
516491
# first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification[].u₁)`):
517492
# checks if we're in the "reduct" part of an outer reduction
518493
#
@@ -524,7 +499,9 @@ function lower_compute!(
524499
# elseif
525500
end
526501
if instr.instr === :identity && isone(length(parents_op))
527-
push!(q.args, Expr(:(=), varsym, instrcall.args[2]))
502+
if instrcall.args[2] !== varsym
503+
push!(q.args, Expr(:(=), varsym, instrcall.args[2]))
504+
end
528505
elseif identifier(op) ∉ ls.outer_reductions && should_broadcast_op(op)
529506
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, instrcall)))
530507
else

src/codegen/lower_memory_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
229229
end
230230
end
231231
else
232-
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, AV, 1, M, 1)
232+
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, 0, 1, M, 1)
233233
end
234234
end
235235
function unrolledindex(op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool})

src/codegen/lower_store.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ function lower_store!(
122122
mvar = Symbol(variable_name(opp, ifelse(isu₂, suffix, -1)), '_', u)
123123
if all(op.ref.loopedindex)
124124
inds = unrolledindex(op, ua, mask, inds_calc_by_ptr_offset)
125-
126125
storeexpr = if reductfunc === Symbol("")
127126
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
128127
else

src/codegen/operation_evaluation_order.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,20 @@ end
2525

2626
function isnopidentity(ls::LoopSet, op::Operation, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol, u₂max::Int)
2727
parents_op = parents(op)
28-
if iscompute(op) && instruction(op).instr === :identity && name(first(parents_op)) === name(op) && isone(length(parents_op))
28+
if iscompute(op) && instruction(op).instr === :identity && isone(length(parents_op)) && name(first(parents_op)) === name(op)
2929
loopistiled = u₂max ≠ -1
30-
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loop, u₂loop, u₂max, Core.ifelse(isu₂unrolled(op), u₂max, -1))
31-
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, u₂max)
32-
if (u₁unrolledsym == first(parents_u₁syms)) && (isu₂unrolled(op) == parents_u₂syms[1])
30+
# mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loop, u₂loop, u₂max, Core.ifelse(isu₂unrolled(op), u₂max, -1))
31+
# parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, u₂max)
32+
# @show (u₁unrolledsym, first(parents_u₁syms)), (isu₂unrolled(op), parents_u₂syms[1])
33+
# @show op parents(op) isu₁unrolled(op), isu₁unrolled(only(parents(op)))
34+
# if (u₁unrolledsym == first(parents_u₁syms)) && (isu₂unrolled(op) == parents_u₂syms[1])
35+
opp = only(parents_op)
36+
if (isu₁unrolled(op) == isu₁unrolled(opp)) & (isu₂unrolled(op) == isu₂unrolled(opp))
3337
#TODO: identifier(first(parents_op)) ∉ ls.outer_reductions is going to miss a lot of cases
3438
#Should probably replace that with `DVec` (demoting Vec) types, that demote to scalar.
3539
#TODO: document (after finding out...) why only checking `isvectorized(first(parents_op))` -- why not `any(isvectorized, parents_op)`???
36-
if (isvectorized(first(parents_op)) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)
37-
op.instruction = reduction_to_scalar(instruction(first(parents_op)))
40+
if (isvectorized(opp) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)
41+
op.instruction = reduction_to_scalar(instruction(opp))
3842
op.mangledvariable = gensym(op.mangledvariable)
3943
false
4044
else

src/reconstruct_loopset.jl

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,45 +174,71 @@ end
174174
# sptrs::Expr, ls::LoopSet, ar::ArrayReferenceMeta, @nospecialize(_::Type{Core.LLVMPtr{T,0}}),
175175
function add_mref!(
176176
sptrs::Expr, ls::LoopSet, ar::ArrayReferenceMeta, @nospecialize(_::Type{Ptr{T}}),
177-
C::Int, B::Int, R::NTuple{N,Int}, name::Symbol
177+
C::Int, B::Int, sp::NTuple{N,Int}, name::Symbol
178178
) where {T,N}
179+
Tsym::Symbol = get(VectorizationBase.JULIA_TYPES, T) do
180+
Symbol(T)
181+
end
182+
add_mref_ptr!(sptrs, ls, ar, Tsym, C, B, sp, name)
183+
end
184+
function add_mref_ptr!(
185+
sptrs::Expr, ls::LoopSet, ar::ArrayReferenceMeta, Tsym::Symbol,
186+
C::Int, B::Int, sp::NTuple{N,Int}, name::Symbol
187+
) where {N}
179188
@assert B ≤ 0 "Batched arrays not supported yet."
180-
sp = rank_to_sortperm(R)
181189
# maybe no change needed? -- optimize common case
182190
column_major = ntuple(identity, N)
183191
li = ar.loopedindex;
184192
if sp === column_major || isone(length(li))
185193
return extract_gsp!(sptrs, name)
186194
end
187-
lic = copy(li);
188-
inds = getindices(ar); indsc = copy(inds);
189-
offsets = ar.ref.offsets; offsetsc = copy(offsets);
190-
195+
permute_mref!(ar, C, sp)
191196
# must now sort array's inds, and stack pointer's
192197
tmpsp = gensym(name)
193198
extract_gsp!(sptrs, tmpsp)
194199
strd_tup = Expr(:tuple)
195200
offsets_tup = Expr(:tuple)
201+
gf = GlobalRef(Core,:getfield)
202+
offsets = gensym(:offsets); strides = gensym(:strides)
203+
pushpreamble!(ls, Expr(:(=), offsets, Expr(:call, gf, tmpsp, QuoteNode(:offsets))))
204+
pushpreamble!(ls, Expr(:(=), strides, Expr(:call, gf, tmpsp, QuoteNode(:strd))))
205+
for (i, p) ∈ enumerate(sp)
206+
push!(strd_tup.args, Expr(:call, gf, strides, p, false))
207+
push!(offsets_tup.args, Expr(:call, gf, offsets, p, false))
208+
end
209+
sptype = Expr(:curly, lv(:StridedPointer), Tsym, N, (C == -1 ? -1 : 1), B, column_major)
210+
sptr = Expr(:call, sptype, Expr(:call, :pointer, tmpsp), strd_tup, offsets_tup)
211+
pushpreamble!(ls, Expr(:(=), name, sptr))
212+
nothing
213+
end
214+
function permute_mref!(ar::ArrayReferenceMeta, C::Int, sp::NTuple{N,Int}) where {N}
215+
sp === ntuple(identity, Val(N)) && return nothing
216+
li = ar.loopedindex; lic = copy(li);
217+
inds = getindices(ar); indsc = copy(inds);
218+
offsets = ar.ref.offsets; offsetsc = copy(offsets);
219+
strides = ar.ref.strides; stridesc = copy(strides);
196220
for (i, p) ∈ enumerate(sp)
197221
li[i] = lic[p]
198222
inds[i] = indsc[p]
199223
offsets[i] = offsetsc[p]
200-
push!(strd_tup.args, :($tmpsp.strd[$p]))
201-
# push!(offsets_tup.args, Expr(:call, lv(:Zero)))
202-
push!(offsets_tup.args, :($tmpsp.offsets[$p]))
224+
strides[i] = stridesc[p]
203225
end
204226
C == -1 && makediscontiguous!(getindices(ar))
205-
sptype = Expr(:curly, lv(:StridedPointer), T, N, (C == -1 ? -1 : 1), 1, column_major)
206-
sptr = Expr(:call, sptype, Expr(:call, :pointer, tmpsp), strd_tup, offsets_tup)
207-
pushpreamble!(ls, Expr(:(=), name, sptr))
208-
nothing
227+
return nothing
209228
end
210229
function add_mref!(
211230
sptrs::Expr, ::LoopSet, ::ArrayReferenceMeta, @nospecialize(_::Type{VectorizationBase.FastRange{T,F,S,O}}),
212231
::Int, ::Int, ::Any, name::Symbol
213232
) where {T,F,S,O}
214233
extract_gsp!(sptrs, name)
215234
end
235+
function create_mrefs!(
236+
ls::LoopSet, arf::Vector{ArrayRefStruct}, as::Vector{Symbol}, os::Vector{Symbol},
237+
nopsv::Vector{NOpsType}, expanded::Vector{Bool}, ::Type{Tuple{}}
238+
)
239+
length(arf) == 0 || throw(ArgumentError("Length of array ref vector should be 0 if there are no stridedpointers."))
240+
Vector{ArrayReferenceMeta}(undef, length(arf))
241+
end
216242
function create_mrefs!(
217243
ls::LoopSet, arf::Vector{ArrayRefStruct}, as::Vector{Symbol}, os::Vector{Symbol},
218244
nopsv::Vector{NOpsType}, expanded::Vector{Bool}, ::Type{VectorizationBase.GroupedStridedPointers{P,C,B,R,I,X,O}}
@@ -222,19 +248,26 @@ function create_mrefs!(
222248
# pushpreamble!(ls, Expr(:(=), sptrs, :(VectorizationBase.stridedpointers(getfield(vargs, 1, false)))))
223249
pushpreamble!(ls, Expr(:(=), sptrs, :(VectorizationBase.stridedpointers(getfield(var"#vargs#", 1, false)))))
224250
j = 0
251+
rank_to_sps = Vector{Tuple{Int,NTuple{<:Any,Int}}}(undef, length(arf))
225252
for i ∈ eachindex(arf)
226253
ar = ArrayReferenceMeta(ls, arf[i], as, os, nopsv, expanded)
227254
duplicate = false
228255
vptrar = vptr(ar)
229256
for k ∈ 1:i-1
230257
if vptr(mrefs[k]) === vptrar
231258
duplicate = true
259+
# if isassigned(rank_to_sps, k)
260+
Cₖ, sp = rank_to_sps[k]
261+
permute_mref!(ar, Cₖ, sp)
262+
# end
232263
break
233264
end
234265
end
235266
if !duplicate
236267
j += 1
237-
add_mref!(sptrs, ls, ar, P.parameters[j], C[j], B[j], R[j], vptr(ar))
268+
sp = rank_to_sortperm(R[j])
269+
rank_to_sps[i] = (C[j],sp)
270+
add_mref!(sptrs, ls, ar, P.parameters[j], C[j], B[j], sp, vptr(ar))
238271
end
239272
mrefs[i] = ar
240273
end
@@ -500,7 +533,11 @@ function avx_loopset(
500533
)
501534
ls = LoopSet(:LoopVectorization)
502535
# TODO: check outer reduction types instead
503-
elementbytes = sizeofeltypes(vargs[1].parameters[1].parameters)
536+
elementbytes = if length(vargs[1].parameters) > 0
537+
sizeofeltypes(vargs[1].parameters[1].parameters)
538+
else
539+
8
540+
end
504541
pushpreamble!(ls, :((var"#loop#bounds#", var"#vargs#") = var"#lv#tuple#args#"))
505542
add_loops!(ls, LPSYM, LB)
506543
resize!(ls.loop_order, ls.loopsymbol_offsets[end])

0 commit comments

Comments
 (0)