Skip to content

Commit 46d9fb5

Browse files
committed
hoist some more constant loads
1 parent c9ec9f4 commit 46d9fb5

File tree

7 files changed

+116
-24
lines changed

7 files changed

+116
-24
lines changed

src/codegen/loopstartstopmanager.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ function cse_constant_offsets!(
324324
end
325325
end
326326
ops = operations(ls)
327+
# @show ar
327328
while licmoffset # repeat until we run out
328329
# ind = indices[ii]
329330
# indices are all the same across operations, so we look to the first for checking compatibility...

src/codegen/lower_constant.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ function lower_constant!(
158158
end
159159

160160
isconstantop(op::Operation) = (instruction(op) == LOOPCONSTANT) || (isconstant(op) && length(loopdependencies(op)) == 0)
161+
function isinitializedconst(op::Operation)
162+
if isconstant(op)
163+
return true
164+
elseif iscompute(op)
165+
for opp parents(op)
166+
isinitializedconst(opp) || return false
167+
end
168+
return true
169+
end
170+
false
171+
end
161172
function constantopname(op::Operation)
162173
instr = instruction(op)
163174
if instr === LOOPCONSTANT
@@ -181,7 +192,13 @@ end
181192
# @inline maybeconvert(::Type{T}, s::Number) where {T} = convert(T, s)
182193
# @inline maybeconvert(::Type{T}, s::T) where {T <: Number} = s
183194
# @inline maybeconvert(::Type, s) = s
184-
195+
function sizeequivalent_symint_expr(intval::Int, signed::Bool)
196+
if signed
197+
Expr(:call, lv(:sizeequivalentint), ELTYPESYMBOL, intval)
198+
else
199+
Expr(:call, lv(:sizeequivalentint), ELTYPESYMBOL, intval % UInt)
200+
end
201+
end
185202

186203
function lower_licm_constants!(ls::LoopSet)
187204
ops = operations(ls)
@@ -197,10 +214,8 @@ function lower_licm_constants!(ls::LoopSet)
197214
for (id,(intval,intsz,signed)) ls.preamble_symint
198215
if intsz == 1
199216
setop!(ls, ops[id], intval % Bool)
200-
elseif signed
201-
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentint), ELTYPESYMBOL, intval))
202217
else
203-
setop!(ls, ops[id], Expr(:call, lv(:sizeequivalentint), ELTYPESYMBOL, intval % UInt))
218+
setop!(ls, ops[id], sizeequivalent_symint_expr(intval, signed))
204219
end
205220
end
206221
for (id,floatval) ls.preamble_symfloat

src/codegen/lower_memory_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function parentind(ind::Symbol, op::Operation)
1+
function parentind(ind::Symbol, op::Union{Operation,ArrayReferenceMetaPosition})
22
for (id,opp) enumerate(parents(op))
33
name(opp) === ind && return id
44
end

src/modeling/determinestrategy.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ function unroll_no_reductions(ls, order, vloopsym)
253253
else
254254
max(1, min(4, round(Int, 1.75compute_rt / load_rt)))
255255
end
256+
# @show load_rt, store_rt, compute_rt, compute_l, u
256257
# u = min(u, max(1, (reg_count(ls) ÷ max(1,round(Int,rp)))))
257258
# commented out here is to decide to align loops
258259
# if memory_rt > compute_rt && isone(u) && (length(order) > 1) && (last(order) === vloopsym) && length(getloop(ls, last(order))) > 8W

src/modeling/operations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ function Operation(id::Int, var::Symbol, elementbytes::Int, instr, optype::Opera
354354
Operation( id, var, elementbytes, instr, optype, mpref.loopdependencies, mpref.reduceddeps, mpref.parents, mpref.mref )
355355
end
356356
Base.:(==)(x::ArrayReferenceMetaPosition, y::ArrayReferenceMetaPosition) = x.mref == y.mref
357+
parents(op::ArrayReferenceMetaPosition) = op.parents
357358
# Avoid memory allocations by using this for ops that aren't references
358359
const NOTAREFERENCE = ArrayReferenceMeta(ArrayReference(Symbol(""), Symbol[]),Bool[],Symbol(""))
359360
const NOTAREFERENCEMP = ArrayReferenceMetaPosition(NOTAREFERENCE, NOPARENTS, Symbol[], Symbol[],Symbol(""))
@@ -364,7 +365,7 @@ loopdependencies(ref::ArrayReferenceMetaPosition) = ref.loopdependencies
364365
reduceddependencies(ref::ArrayReferenceMetaPosition) = ref.reduceddeps
365366
arrayref(ref::ArrayReference) = ref
366367
arrayref(ref::ArrayReferenceMeta) = ref.ref
367-
arrayref(ref::ArrayReferenceMetaPosition) = ref.ref.ref
368+
arrayref(ref::ArrayReferenceMetaPosition) = ref.mref.ref
368369
arrayref(op::Operation) = op.ref.ref
369370
getindices(ref) = arrayref(ref).indices
370371
getoffsets(ref) = arrayref(ref).offsets

src/parse/add_constants.jl

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,83 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
4040
pushpreamble!(ls, Expr(:(=), name(op), var))
4141
rop
4242
end
43+
function ensure_constant_lowered!(ls::LoopSet, op::Operation)
44+
if iscompute(op)
45+
call = callexpr(instruction(op))
46+
for opp parents(op)
47+
ensure_constant_lowered!(ls, opp)
48+
push!(call.args, name(opp))
49+
end
50+
pushpreamble!(ls, Expr(:(=), name(op), call))
51+
elseif isconstant(op) & !isconstantop(op)
52+
opid = identifier(op)
53+
for (id, sym) ls.preamble_symsym
54+
if id == opid
55+
pushpreamble!(ls, Expr(:(=), name(op), sym))
56+
return
57+
end
58+
end
59+
for (id,(intval,intsz,signed)) ls.preamble_symint
60+
if id == opid
61+
if intsz == 1
62+
pushpreamble!(ls, Expr(:(=), name(op), intval % Bool))
63+
elseif signed
64+
pushpreamble!(ls, Expr(:(=), name(op), intval))
65+
else
66+
pushpreamble!(ls, Expr(:(=), name(op), intval % UInt))
67+
end
68+
return
69+
end
70+
end
71+
for (id,floatval) ls.preamble_symfloat
72+
if id == opid
73+
pushpreamble!(ls, Expr(:(=), name(op), floatval))
74+
return
75+
end
76+
77+
end
78+
for (id,typ) ls.preamble_zeros
79+
if id == opid
80+
pushpreamble!(ls, Expr(:(=), name(op), staticexpr(0)))
81+
return
82+
end
83+
end
84+
for (id,f) ls.preamble_funcofeltypes
85+
if id == opid
86+
pushpreamble!(ls, Expr(:(=), name(op), Expr(:call, reduction_zero(f), Float64)))
87+
return
88+
end
89+
end
90+
end
91+
end
92+
function ensure_constant_lowered!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, ind::Symbol)
93+
length(loopdependencies(mpref)) == 0 && return
94+
for (id,opp) enumerate(parents(mpref))
95+
if name(opp) === ind
96+
ensure_constant_lowered!(ls, opp)
97+
end
98+
end
99+
return nothing
100+
end
43101
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
44-
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
45-
add_vptr!(ls, op)
46-
temp = gensym!(ls, "intermediateconstref")
47-
vloadcall = Expr(:call, lv(:_vload), mpref.mref.ptr)
48-
nindices = length(getindices(op))
49-
# getoffsets(op) .+= 1
50-
if nindices > 0
51-
dummyloop = first(ls.loops)
52-
push!(vloadcall.args, mem_offset(op, UnrollArgs(dummyloop, dummyloop, dummyloop, 0, 0, 0), fill(false,nindices), true, ls))
102+
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
103+
add_vptr!(ls, op)
104+
temp = gensym!(ls, "intermediateconstref")
105+
vloadcall = Expr(:call, lv(:_vload), mpref.mref.ptr)
106+
nindices = length(getindices(op))
107+
# getoffsets(op) .+= 1
108+
if nindices > 0
109+
dummyloop = first(ls.loops)
110+
for ind getindicesonly(op)
111+
ensure_constant_lowered!(ls, mpref, ind)
53112
end
54-
push!(vloadcall.args, Expr(:call, lv(:False)), staticexpr(reg_size(ls)))
55-
pushpreamble!(ls, Expr(:(=), temp, vloadcall))
56-
pushpreamble!(ls, Expr(:(=), name(op), temp))
57-
pushpreamble!(ls, op, temp)
58-
pushop!(ls, op, temp)
113+
push!(vloadcall.args, mem_offset(op, UnrollArgs(dummyloop, dummyloop, dummyloop, 0, 0, 0), fill(false,nindices), true, ls))
114+
end
115+
push!(vloadcall.args, Expr(:call, lv(:False)), staticexpr(reg_size(ls)))
116+
pushpreamble!(ls, Expr(:(=), temp, vloadcall))
117+
pushpreamble!(ls, Expr(:(=), name(op), temp))
118+
pushpreamble!(ls, op, temp)
119+
pushop!(ls, op, temp)
59120
end
60121
# This version has loop dependencies. var gets assigned to sym when lowering.
61122
# value is what will get assigned within the loop.

src/parse/add_loads.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,25 @@ function add_load!(
4646
mpref = array_reference_meta!(ls, array, rawindices, elementbytes, var)
4747
add_load!(ls, mpref, elementbytes)
4848
end
49+
function load_is_constant(ls::LoopSet, mpref::ArrayReferenceMetaPosition)
50+
li = mpref.mref.loopedindex
51+
inds = getindicesonly(mpref)
52+
for i eachindex(li)
53+
li[i] && return false
54+
if (id = parentind(inds[i], mpref)) > 0
55+
isinitializedconst(parents(mpref)[id]) || return false
56+
end
57+
end
58+
true
59+
end
4960
function add_load!(
50-
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int
61+
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int
5162
)
52-
iszero(length(mpref.loopdependencies)) && return add_constant!(ls, mpref, elementbytes)
53-
op = Operation( ls, varname(mpref), elementbytes, :getindex, memload, mpref )
54-
add_load!(ls, op, true)
63+
if length(mpref.loopdependencies) == 0 || load_is_constant(ls, mpref)
64+
return add_constant!(ls, mpref, elementbytes)
65+
end
66+
op = Operation( ls, varname(mpref), elementbytes, :getindex, memload, mpref )
67+
add_load!(ls, op, true)
5568
end
5669

5770
# for use with broadcasting

0 commit comments

Comments
 (0)