Skip to content

Commit 9d45ff7

Browse files
committed
hacky fix for CartesianIndices{0} stores that works for some cases
1 parent fb8e8bd commit 9d45ff7

File tree

7 files changed

+145
-72
lines changed

7 files changed

+145
-72
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ include("codegen/lowering.jl")
9191
include("codegen/split_loops.jl")
9292
include("codegen/lower_threads.jl")
9393
include("condense_loopset.jl")
94+
include("transforms.jl")
9495
include("reconstruct_loopset.jl")
9596
include("constructors.jl")
9697
include("user_api_conveniences.jl")

src/codegen/loopstartstopmanager.jl

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -37,35 +37,36 @@ function uniquearrayrefs_csesummary(ls::LoopSet)
3737
end
3838

3939
function uniquearrayrefs(ls::LoopSet)
40-
uniquerefs = ArrayReferenceMeta[]
41-
includeinlet = Bool[]
42-
# for arrayref ∈ ls.refs_aliasing_syms
43-
for op operations(ls)
44-
arrayref = op.ref
45-
arrayref === NOTAREFERENCE && continue
46-
notunique = false
47-
isonlyname = true
48-
for ref uniquerefs
49-
notunique = sameref(arrayref, ref)
50-
isonlyname &= vptr(arrayref) !== vptr(ref)
51-
# if they're not the sameref, they may still have the same name
52-
# if they have different names, they're definitely not sameref
53-
notunique && break
54-
end
55-
if !notunique
56-
push!(uniquerefs, arrayref)
57-
push!(includeinlet, isonlyname)
58-
end
40+
uniquerefs = ArrayReferenceMeta[]
41+
includeinlet = Bool[]
42+
# for arrayref ∈ ls.refs_aliasing_syms
43+
for op operations(ls)
44+
op.instruction === DROPPEDCONSTANT && continue
45+
arrayref = op.ref
46+
arrayref === NOTAREFERENCE && continue
47+
notunique = false
48+
isonlyname = true
49+
for ref uniquerefs
50+
notunique = sameref(arrayref, ref)
51+
isonlyname &= vptr(arrayref) !== vptr(ref)
52+
# if they're not the sameref, they may still have the same name
53+
# if they have different names, they're definitely not sameref
54+
notunique && break
5955
end
60-
uniquerefs, includeinlet
56+
if !notunique
57+
push!(uniquerefs, arrayref)
58+
push!(includeinlet, isonlyname)
59+
end
60+
end
61+
uniquerefs, includeinlet
6162
end
6263

6364
otherindexunrolled(loopsym::Symbol, ind::Symbol, loopdeps::Vector{Symbol}) = (loopsym !== ind) && (loopsym loopdeps)
6465
function otherindexunrolled(ls::LoopSet, ind::Symbol, ref::ArrayReferenceMeta)
65-
us = ls.unrollspecification
66-
u₁sym = names(ls)[us.u₁loopnum]
67-
u₂sym = us.u₂loopnum > 0 ? names(ls)[us.u₂loopnum] : Symbol("##undefined##")
68-
otherindexunrolled(u₁sym, ind, loopdependencies(ref)) || otherindexunrolled(u₂sym, ind, loopdependencies(ref))
66+
us = ls.unrollspecification
67+
u₁sym = names(ls)[us.u₁loopnum]
68+
u₂sym = us.u₂loopnum > 0 ? names(ls)[us.u₂loopnum] : Symbol("##undefined##")
69+
otherindexunrolled(u₁sym, ind, loopdependencies(ref)) || otherindexunrolled(u₂sym, ind, loopdependencies(ref))
6970
end
7071
function multiple_with_name(n::Symbol, v::Vector{ArrayReferenceMeta})
7172
found = false
@@ -79,26 +80,26 @@ end
7980
# multiple_with_name(n::Symbol, v::Vector{ArrayReferenceMeta}) = sum(ref -> n === vptr(ref), v) > 1
8081
# TODO: DRY between indices_calculated_by_pointer_offsets and use_loop_induct_var
8182
function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMeta)
82-
indices = getindices(ar)
83-
ls.isbroadcast && return fill(false, length(indices))
84-
looporder = names(ls)
85-
offset = isdiscontiguous(ar)
86-
gespinds = Expr(:tuple)
87-
out = Vector{Bool}(undef, length(indices))
88-
li = ar.loopedindex
89-
for i eachindex(li)
90-
ii = i + offset
91-
ind = indices[ii]
92-
if (!li[i]) || (ind === CONSTANTZEROINDEX) || multiple_with_name(vptr(ar), ls.lssm.uniquearrayrefs) ||
93-
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
94-
out[i] = false
95-
elseif (isone(ii) && (first(looporder) === ind))
96-
out[i] = otherindexunrolled(ls, ind, ar)
97-
else
98-
out[i] = true
99-
end
83+
indices = getindices(ar)
84+
ls.isbroadcast && return fill(false, length(indices))
85+
looporder = names(ls)
86+
offset = isdiscontiguous(ar)
87+
gespinds = Expr(:tuple)
88+
out = Vector{Bool}(undef, length(indices))
89+
li = ar.loopedindex
90+
for i eachindex(li)
91+
ii = i + offset
92+
ind = indices[ii]
93+
if (!li[i]) || (ind === CONSTANTZEROINDEX) || multiple_with_name(vptr(ar), ls.lssm.uniquearrayrefs) ||
94+
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
95+
out[i] = false
96+
elseif (isone(ii) && (first(looporder) === ind))
97+
out[i] = otherindexunrolled(ls, ind, ar)
98+
else
99+
out[i] = true
100100
end
101-
out
101+
end
102+
out
102103
end
103104

104105
# @generated function set_first_stride(sptr::StridedPointer{T,N,C,B,R}) where {T,N,C,B,R}
@@ -139,8 +140,8 @@ end
139140

140141
# end
141142
function set_ref_loopedindex_and_ind!(ref::ArrayReferenceMeta, i::Int, ii::Int, li::Bool, ind::Symbol)
142-
ref.loopedindex[i] = li
143-
getindices(ref)[ii] = ind
143+
ref.loopedindex[i] = li
144+
getindices(ref)[ii] = ind
144145
end
145146
function set_all_to_constant_index!(
146147
ls::LoopSet, i::Int, ii::Int, indop::Operation, allarrayrefs::Vector{ArrayReferenceMeta},

src/condense_loopset.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,7 @@ make_crashy(q) = Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,
741741

742742
@inline vecmemaybe(x::NativeTypes) = x
743743
@inline vecmemaybe(x::VectorizationBase._Vec) = Vec(x)
744+
@inline vecmemaybe(x::VectorizationBase.Vec) = x
744745
@inline vecmemaybe(x::Tuple) = VectorizationBase.VecUnroll(x)
745746
@inline vecmemaybe(x::Mask) = x
746747

@@ -770,28 +771,31 @@ end
770771
# setup_call_ret!(ls, call, preserve)
771772
# end
772773
setup_outerreduct_preserve_mangler(op::Operation) = Symbol(mangledvar(op), "##onevec##")
774+
775+
function outer_reduction_to_scalar_reduceq!(q::Expr, op::Operation, var = name(op))
776+
instr = instruction(op)
777+
out = setup_outerreduct_preserve_mangler(op)
778+
if instr.instr :ifelse
779+
Expr(:call, reduction_scalar_combine(op), Expr(:call, lv(:vecmemaybe), out), var)
780+
else
781+
opinstr = ifelse_reduction(:IfElseReduced, op) do opv
782+
opvname = name(opv)
783+
oporig = gensym(opvname)
784+
pushfirst!(q.args, Expr(:(=), oporig, opvname))
785+
Expr(:call, lv(:vecmemaybe), setup_outerreduct_preserve_mangler(opv)), (oporig,)
786+
end
787+
Expr(:call, opinstr, Expr(:call, lv(:vecmemaybe), out), var)
788+
end
789+
end
773790
function setup_outerreduct_preserve(ls::LoopSet, call::Expr, preserve::Vector{Symbol})
774791
iszero(length(ls.outer_reductions)) && return gc_preserve(call, preserve)
775792
retv = loopset_return_value(ls, Val(false))
776793
q = Expr(:block, gc_preserve(Expr(:(=), retv, call), preserve))
777794
for or ls.outer_reductions
778795
op = ls.operations[or]
779-
var = name(op)
780796
# push!(call.args, Symbol("##TYPEOF##", var))
781-
instr = instruction(op)
782-
out = setup_outerreduct_preserve_mangler(op)
783-
reducq = if instr.instr :ifelse
784-
Expr(:call, reduction_scalar_combine(op), Expr(:call, lv(:vecmemaybe), out), var)
785-
else
786-
opinstr = ifelse_reduction(:IfElseReduced, op) do opv
787-
opvname = name(opv)
788-
oporig = gensym(opvname)
789-
pushfirst!(q.args, Expr(:(=), oporig, opvname))
790-
Expr(:call, lv(:vecmemaybe), setup_outerreduct_preserve_mangler(opv)), (oporig,)
791-
end
792-
Expr(:call, opinstr, Expr(:call, lv(:vecmemaybe), out), var)
793-
end
794-
push!(q.args, Expr(:(=), var, reducq))
797+
reducq = outer_reduction_to_scalar_reduceq!(q, op)
798+
push!(q.args, Expr(:(=), name(op), reducq))
795799
end
796800
q
797801
end

src/modeling/operations.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -375,19 +375,18 @@ getstrides(ref) = arrayref(ref).strides
375375

376376
isdiscontiguous(ref) = isdiscontiguous_inds(getindices(ref))
377377
function isdiscontiguous_inds(inds)
378-
# (first(inds) === DISCONTIGUOUS) || (first(inds) === CONSTANTZEROINDEX)
379-
first(inds) === DISCONTIGUOUS
378+
length(inds) == 0 ? false : (@inbounds(inds[begin]) === DISCONTIGUOUS)
380379
end
381380
function makediscontiguous!(inds)
382-
if iszero(length(inds)) || !isdiscontiguous_inds(inds)
383-
pushfirst!(inds, DISCONTIGUOUS)
384-
end
385-
nothing
381+
if iszero(length(inds)) || !isdiscontiguous_inds(inds)
382+
pushfirst!(inds, DISCONTIGUOUS)
383+
end
384+
nothing
386385
end
387386

388387
function getindicesonly(ref)
389-
indices = getindices(ref)
390-
@view(indices[isdiscontiguous(ref) + 1:end])
388+
indices = getindices(ref)
389+
@view(indices[isdiscontiguous(ref) + 1:end])
391390
end
392391
# function hasintersection(s1::Set{T}, s2::Set{T}) where {T}
393392
# for x ∈ s1

src/parse/add_constants.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,7 @@ function ensure_constant_lowered!(ls::LoopSet, mpref::ArrayReferenceMetaPosition
9999
end
100100
return nothing
101101
end
102-
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
103-
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
104-
add_vptr!(ls, op)
102+
function add_constant_vload!(ls::LoopSet, op::Operation, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
105103
temp = gensym!(ls, "intermediateconstref")
106104
vloadcall = Expr(:call, lv(:_vload), mpref.mref.ptr)
107105
nindices = length(getindices(op))
@@ -117,6 +115,12 @@ function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementby
117115
pushpreamble!(ls, Expr(:(=), temp, vloadcall))
118116
pushpreamble!(ls, Expr(:(=), name(op), temp))
119117
pushpreamble!(ls, op, temp)
118+
return temp
119+
end
120+
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
121+
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
122+
add_vptr!(ls, op)
123+
temp = add_constant_vload!(ls, op, mpref, elementbytes)
120124
pushop!(ls, op, temp)
121125
end
122126
# This version has loop dependencies. var gets assigned to sym when lowering.

src/reconstruct_loopset.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,9 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
716716
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
717717
ls = _turbo_loopset(var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#".parameters, var"#V#".parameters, var"#UNROLL#")
718718
pushfirst!(ls.preamble.args, :(var"#lv#tuple#args#" = reassemble_tuple(Tuple{var"#LB#",var"#V#"}, var"#flattened#var#arguments#")))
719+
post = hoist_constant_memory_accesses!(ls)
719720
# return @show avx_body(ls, var"#UNROLL#")
720-
if last(var"#UNROLL#") > 1
721+
q = if last(var"#UNROLL#") > 1
721722
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = var"#UNROLL#"
722723
# wrap in `var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#"` in `Expr` to homogenize types
723724
avx_threads_expr(
@@ -729,5 +730,6 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
729730
# return @show avx_body(ls, var"#UNROLL#")
730731
avx_body(ls, var"#UNROLL#")
731732
end
733+
post === ls.preamble ? q : Expr(:block, q, post)
732734
# @show var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#"
733735
end

src/transforms.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# file for misc loopset transforms
2+
3+
function hoist_constant_memory_accesses!(ls::LoopSet)
4+
hoist_stores = false
5+
for op operations(ls)
6+
if isload(op)
7+
length(getindicesonly(op)) == 0 && host_constant_load!(ls, op)
8+
elseif isstore(op) && iszero(length(getindicesonly(op)))
9+
hoist_stores = true
10+
end
11+
end
12+
hoist_stores && return hoist_constant_memory_accesses_nocheck!(ls)
13+
ls.preamble
14+
end
15+
16+
function hoist_constant_memory_accesses_nocheck!(ls::LoopSet)
17+
post = Expr(:block)
18+
for op operations(ls)
19+
if isstore(op) && length(getindicesonly(op)) == 0
20+
hoist_constant_store!(post, ls, op)
21+
end
22+
end
23+
post
24+
end
25+
function hoist_constant_vload!(ls::LoopSet, op::Operation)
26+
op.instr = LOOPCONSTANT
27+
op.node_type = constant
28+
add_constant_vload!(ls, op, ArrayReferenceMetaPosition(op.ref, parents(op), loopdependencies(op), reduceddependencies(op), name(op)), elementbytes)
29+
end
30+
31+
function return_empty_reductinit(op::Operation, var::Symbol)
32+
for (i,opp) enumerate(parents(op))
33+
if (name(opp) === var) && (length(reduceddependencies(opp)) == 0) && (length(loopdependencies(opp)) == 0) && (length(children(opp)) == 1)
34+
return opp
35+
end
36+
opcheck = return_empty_reductinit(opp, var)
37+
opcheck === opp || return opcheck
38+
end
39+
return op
40+
end
41+
42+
43+
function hoist_constant_store!(q::Expr, ls::LoopSet, op::Operation)
44+
op.instruction = DROPPEDCONSTANT
45+
op.node_type = constant
46+
47+
opr = only(parents(op))
48+
while opr.instruction.instr === :identity
49+
opr.instruction = DROPPEDCONSTANT
50+
opr.node_type = constant
51+
opr = only(parents(opr))
52+
end
53+
push!(ls.outer_reductions, identifier(opr))
54+
55+
init = return_empty_reductinit(opr, name(opr)).instruction.instr
56+
pushpreamble!(ls, Expr(:(=), outer_reduct_init_typename(opr), Expr(:call, lv(:typeof), init)))
57+
q = Expr(:block)
58+
push!(q.args, Expr(:call, lv(:unsafe_store!), Expr(:call, lv(:pointer), op.ref.ptr), outer_reduction_to_scalar_reduceq!(q, opr, init)))
59+
length(q.args) == 0 || pushpreamble!(ls, q) # creating `Expr` and pushing because `outer_reduction_to_scalar_reduceq!` uses `pushfirst!(q.args`, and we don't want it at the start of the preamble
60+
return nothing
61+
end
62+

0 commit comments

Comments
 (0)