Skip to content

Commit f7afbe4

Browse files
committed
Attempted to get LLVM to not calculate addresses using SIMD instructions + extracts for the convolution example, but failed. Leaving that infrastructure there.
1 parent 563467b commit f7afbe4

File tree

4 files changed

+144
-34
lines changed

4 files changed

+144
-34
lines changed

src/determinestrategy.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,16 +453,16 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
453453
end
454454
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
455455
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
456-
(vectorized == u₁loopsym || vectorized == u₂loopsym) && return false, false
457-
(isu₁unrolled(op) && isu₂unrolled(op)) || return false, false
458-
istranslation = false
456+
(vectorized == u₁loopsym || vectorized == u₂loopsym) && return 0, false
457+
(isu₁unrolled(op) && isu₂unrolled(op)) || return 0, false
458+
istranslation = 0
459459
inds = getindices(op); li = op.ref.loopedindex
460460
translationplus = false
461461
for i eachindex(li)
462462
if !li[i]
463463
opp = findparent(ls, inds[i + (first(inds) === Symbol("##DISCONTIGUOUSSUBARRAY##"))])
464464
if instruction(opp).instr (:+, :-) && u₁loopsym loopdependencies(opp) && u₂loopsym loopdependencies(opp)
465-
istranslation = true
465+
istranslation = i
466466
translationplus = instruction(opp).instr === :+
467467
end
468468
end
@@ -527,7 +527,7 @@ function maxnegativeoffset(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols
527527
end
528528
function load_elimination_cost_factor(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
529529
@unpack u₁loopsym, u₂loopsym = unrollsyms
530-
if first(isoptranslation(ls, op, unrollsyms))
530+
if !iszero(first(isoptranslation(ls, op, unrollsyms)))
531531
for loop ls.loops
532532
# If another loop is short, assume that LLVM will unroll it, in which case
533533
# we want to be a little more conservative in terms of register pressure.

src/lower_load.jl

Lines changed: 131 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -142,45 +142,150 @@ function lower_load_vectorized!(
142142
end
143143
nothing
144144
end
145+
"""
    indisvectorized(ls::LoopSet, ind::Symbol, vectorized::Symbol)

Return `true` if any operation in `ls` whose `variable` is `ind` satisfies
`isvectorized`, and `false` otherwise.

`vectorized` is part of the signature but is not read by the body.
"""
function indisvectorized(ls::LoopSet, ind::Symbol, vectorized::Symbol)
    return any(op -> (op.variable === ind) && isvectorized(op), operations(ls))
end
151+
152+
# function lower_load_for_optranslation!(
153+
# q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs{Int}, mask::Union{Nothing,Symbol,Unsigned}, translationind::Int
154+
# )
155+
# @unpack u₁, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix = td
156+
# iszero(suffix) || return
157+
158+
# gespinds = mem_offset(op, UnrollArgs(td, 0), indices_calculated_by_pointer_offsets(ls, op.ref), false)
159+
# ptr = vptr(op)
160+
# gptr = Symbol(ptr, "##GESPED##")
161+
# for i ∈ eachindex(gespinds.args)
162+
# if i != translationind
163+
# gespinds.args[i] = Expr(:call, lv(:extract_data), gespinds.args[i])
164+
# end
165+
# end
166+
# push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
167+
168+
# inds = Expr(:tuple)
169+
# ginds = Expr(:tuple)
170+
# indices = getindicesonly(op)
171+
172+
# for (i,ind) ∈ enumerate(indices)
173+
# if i == translationind # ind cannot be the translation ind
174+
# push!(inds.args, Expr(:call, lv(:Zero)))
175+
# push!(ginds.args, Expr(:call, Expr(:curly, lv(:Static), 1)))
176+
# elseif (ind === vectorized) || indisvectorized(ls, ind, vectorized)
177+
# push!(inds.args, _MMind(Expr(:call, lv(:Zero))))
178+
# push!(ginds.args, Expr(:call, lv(:Zero)))
179+
# else
180+
# push!(inds.args, Expr(:call, lv(:Zero)))
181+
# push!(ginds.args, Expr(:call, lv(:Zero)))
182+
# end
183+
# end
184+
# varbase = variable_name(op, 0)
185+
# vloadexpr = Expr(:call, lv(:vload), gptr, inds)
186+
# gespexpr = Expr(:(=), gptr, Expr(:call, lv(:gesp), gptr, ginds))
187+
# push!(q.args, Expr(:(=), Symbol(varbase, 0), vloadexpr))
188+
189+
# for u ∈ 1:u₁-1
190+
# push!(q.args, gespexpr)
191+
# push!(q.args, Expr(:(=), Symbol(varbase, u), vloadexpr))
192+
# end
193+
# # this takes care of u₂ == 0
194+
# offset = u₁
195+
# for u₂ ∈ 1:u₂max-1
196+
# varold = varbase
197+
# varbase = variable_name(op, u₂)
198+
# for u ∈ 0:u₁-2
199+
# push!(q.args, Expr(:(=), Symbol(varbase, u), Symbol(varold, u + 1)))
200+
# end
201+
# push!(q.args, gespexpr)
202+
# push!(q.args, Expr(:(=), Symbol(varbase, u₁ - 1), vloadexpr))
203+
# offset += 1
204+
# end
205+
# nothing
206+
# end
207+
208+
209+
"""
    lower_load_for_optranslation!(q, op, ls, td, mask, translationind)

Push onto `q.args` the load expressions for `op`, where `translationind` is the
position of the index flagged by `isoptranslation`.

Does nothing unless `td.suffix` is zero. When it is zero, the emitted code:

1. Binds a `##GESPED##`-suffixed pointer symbol to `gesp(ptr, offsets)`. Every
   offset except the one at `translationind` is wrapped in `extract_data`.
2. Loads `u₁` values from that pointer, using `Static{u}()` for `u = 0:u₁-1` at
   `translationind`. Other positions are `Zero()`, passed through `_MMind` when
   the index is vectorized.
3. For each `u₂ = 1:u₂max-1`, reuses the previous `u₂`'s variables shifted by
   one (`new[u] = old[u + 1]`), and issues one new load, at static offset
   `u₁ + u₂ - 1`, for the last slot.

`mask` is part of the signature but is not read by the body. Returns `nothing`.
"""
function lower_load_for_optranslation!(
    q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs{Int}, mask::Union{Nothing,Symbol,Unsigned}, translationind::Int
)
    @unpack u₁, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix = td
    # Everything for all `u₂` values is emitted in one go on the suffix-0 call.
    iszero(suffix) || return

    # Offsets are requested with `_mm = false`.
    gespinds = mem_offset(op, UnrollArgs(td, 0), indices_calculated_by_pointer_offsets(ls, op.ref), false)
    ptr = vptr(op)
    gptr = Symbol(ptr, "##GESPED##")
    for i ∈ eachindex(gespinds.args)
        if i != translationind
            gespinds.args[i] = Expr(:call, lv(:extract_data), gespinds.args[i])
        end
    end
    push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))

    # Index tuple relative to `gptr`; the `translationind` slot is overwritten
    # below with successive `Static{k}()` offsets.
    inds = Expr(:tuple)
    indices = getindicesonly(op)

    for (i, ind) ∈ enumerate(indices)
        if i == translationind # ind cannot be the translation ind
            push!(inds.args, Expr(:call, Expr(:curly, lv(:Static), 0)))
        elseif (ind === vectorized) || indisvectorized(ls, ind, vectorized)
            push!(inds.args, _MMind(Expr(:call, lv(:Zero))))
        else
            push!(inds.args, Expr(:call, lv(:Zero)))
        end
    end
    varbase = variable_name(op, 0)
    # `copy(inds)` is required: `inds` is mutated again for each later load.
    push!(q.args, Expr(:(=), Symbol(varbase, 0), Expr(:call, lv(:vload), gptr, copy(inds))))

    for u ∈ 1:u₁-1
        inds.args[translationind] = Expr(:call, Expr(:curly, lv(:Static), u))
        push!(q.args, Expr(:(=), Symbol(varbase, u), Expr(:call, lv(:vload), gptr, copy(inds))))
    end
    # this takes care of u₂ == 0
    offset = u₁
    for u₂ ∈ 1:u₂max-1
        varold = varbase
        varbase = variable_name(op, u₂)
        # Reuse the previous `u₂`'s loads, shifted down by one slot.
        for u ∈ 0:u₁-2
            push!(q.args, Expr(:(=), Symbol(varbase, u), Symbol(varold, u + 1)))
        end
        # Only the final slot needs a fresh load.
        inds.args[translationind] = Expr(:call, Expr(:curly, lv(:Static), offset))
        push!(q.args, Expr(:(=), Symbol(varbase, u₁ - 1), Expr(:call, lv(:vload), gptr, copy(inds))))
        offset += 1
    end
    nothing
end
145258

146259
# TODO: this code should be rewritten to be more "orthogonal", so that we're just combining separate pieces.
147260
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
148261
function lower_load!(
149262
q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs, mask::Union{Nothing,Symbol,Unsigned} = nothing
150263
)
151264
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = td
152-
if !isnothing(suffix) && suffix > 0
265+
if !isnothing(suffix) && ls.loadelimination[]
153266
istr, ispl = isoptranslation(ls, op, UnrollSymbols(u₁loopsym, u₂loopsym, vectorized))
154-
if istr && ispl
155-
varnew = variable_name(op, suffix)
156-
varold = variable_name(op, suffix - 1)
157-
for u 0:u₁-2
158-
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u + 1)))
159-
end
160-
umin = u₁ - 1
161-
elseif u₂loopsym !== vectorized
162-
mno, id = maxnegativeoffset(ls, op, u₂loopsym)
163-
if -suffix < mno < 0
164-
varnew = variable_name(op, suffix)
165-
varold = variable_name(operations(ls)[id], suffix + mno)
166-
opold = operations(ls)[id]
167-
if isu₁unrolled(op)
168-
for u 0:u₁-1
169-
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
267+
if !iszero(istr) & ispl
268+
lower_load_for_optranslation!(q, op, ls, td, mask, istr)
269+
elseif suffix > 0
270+
if u₂loopsym !== vectorized
271+
mno, id = maxnegativeoffset(ls, op, u₂loopsym)
272+
if -suffix < mno < 0
273+
varnew = variable_name(op, suffix)
274+
varold = variable_name(operations(ls)[id], suffix + mno)
275+
opold = operations(ls)[id]
276+
if isu₁unrolled(op)
277+
for u 0:u₁-1
278+
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
279+
end
280+
else
281+
push!(q.args, Expr(:(=), varnew, varold))
170282
end
171-
else
172-
push!(q.args, Expr(:(=), varnew, varold))
283+
return
173284
end
174-
return
175-
else
176-
umin = 0
177285
end
178-
else
179-
umin = 0
180286
end
181-
else
182-
umin = 0
183287
end
288+
umin = 0
184289
if isvectorized(op)
185290
lower_load_vectorized!(q, ls, op, td, mask, umin)
186291
else

src/lower_memory_common.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ unrolled loads are calculated as offsets with respect to an initial gesp. This h
6666
Therefore, unrolled === true results in inds being ignored.
6767
_mm means to insert `mm`s.
6868
"""
69-
function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vector{Bool})
69+
function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vector{Bool}, _mm::Bool = true)
7070
# @assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
7171
ret = Expr(:tuple)
7272
indices = getindicesonly(op)
@@ -81,9 +81,9 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
8181
# else
8282
if loopedindex[n]
8383
if inds_calc_by_ptr_offset[n]
84-
addoffset!(ret, offset, ind === vectorized)
84+
addoffset!(ret, offset, _mm & (ind === vectorized))
8585
else
86-
addoffset!(ret, ind, offset, ind === vectorized)
86+
addoffset!(ret, ind, offset, _mm & (ind === vectorized))
8787
end
8888
else
8989
addoffset!(ret, symbolind(ind, op, td), offset)
@@ -92,6 +92,7 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
9292
ret
9393
end
9494

95+
9596
function add_vectorized_offset!(ret::Expr, ind, offset, incr)
9697
if isone(incr)
9798
if iszero(offset)

src/lowering.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
238238
q = if nisvectorized
239239
# Expr(:block, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
240240
Expr(:block, Expr(:while, expect(tc), body))
241+
elseif isstaticloop(loop) && length(loop) 4
242+
qt = Expr(:block)
243+
foreach(_ -> push!(qt.args, body), 1:length(loop))
244+
qt
241245
else
242246
# Expr(:block, sl, assume(tc), Expr(:while, tc, body))
243247
push!(body.args, Expr(:||, expect(tc), Expr(:break)))

0 commit comments

Comments
 (0)