Skip to content

Commit 31e640a

Browse files
committed
Changed where masks and eltype are defined for reconstructed loopsets.
1 parent c3e69df commit 31e640a

File tree

5 files changed

+35
-22
lines changed

5 files changed

+35
-22
lines changed

src/add_loads.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,19 @@ end
3030

3131
# for use with broadcasting
3232
function add_simple_load!(
33-
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int, actualarray::Bool = true, broadcast::Bool = false
33+
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int,
34+
actualarray::Bool = true, broadcast::Bool = false
3435
)
3536
loopdeps = Symbol[s for s ∈ ref.indices]
3637
mref = ArrayReferenceMeta(
3738
ref, fill(true, length(loopdeps))
3839
)
40+
add_simple_load!(ls, var, mref, loopdeps, elementbytes, actualarray, broadcast)
41+
end
42+
function add_simple_load!(
43+
ls::LoopSet, var::Symbol, mref::ArrayReferenceMeta, loopdeps::Vector{Symbol},
44+
elementbytes::Int, actualarray::Bool = true, broadcast::Bool = false
45+
)
3946
op = Operation(
4047
length(operations(ls)), var, elementbytes,
4148
:getindex, memload, loopdeps,

src/lowering.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -303,20 +303,22 @@ function determine_eltype(ls::LoopSet)
303303
end
304304
promote_q
305305
end
306-
function determine_width(ls::LoopSet, vectorized::Symbol)
306+
function determine_width(
307+
ls::LoopSet, vectorized::Symbol
308+
)
307309
vloop = getloop(ls, vectorized)
308310
vwidth_q = Expr(:call, lv(:pick_vector_width_val))
309311
if isstaticloop(vloop)
310312
push!(vwidth_q.args, Expr(:call, Expr(:curly, :Val, length(vloop))))
311313
end
312314
# push!(vwidth_q.args, ls.T)
313-
# if length(ls.includedactualarrays) < 2
314-
push!(vwidth_q.args, ls.T)
315-
# else
316-
# for array ∈ ls.includedactualarrays
317-
# push!(vwidth_q.args, Expr(:call, :eltype, array))
318-
# end
319-
# end
315+
if length(ls.includedactualarrays) < 2
316+
push!(vwidth_q.args, ls.T)
317+
else
318+
for array ∈ ls.includedactualarrays
319+
push!(vwidth_q.args, Expr(:call, :eltype, array))
320+
end
321+
end
320322
vwidth_q
321323
end
322324
function init_remblock(unrolledloop::Loop, u₁loop::Symbol = unrolledloop.itersymbol)
@@ -358,10 +360,10 @@ function setup_preamble!(ls::LoopSet, us::UnrollSpecification)
358360
vectorized = order[vectorizedloopnum]
359361
# println("Setup preamble")
360362
W = ls.W; typeT = ls.T
361-
if length(ls.includedarrays) > 0
363+
if length(ls.includedactualarrays) > 0
362364
push!(ls.preamble.args, Expr(:(=), typeT, determine_eltype(ls)))
365+
push!(ls.preamble.args, Expr(:(=), W, determine_width(ls, vectorized)))
363366
end
364-
push!(ls.preamble.args, Expr(:(=), W, determine_width(ls, vectorized)))
365367
lower_licm_constants!(ls)
366368
pushpreamble!(ls, definemask(getloop(ls, vectorized), W))#, u₁ > 1 && u₁loopnum == vectorizedloopnum))
367369
for op ∈ operations(ls)

src/reconstruct_loopset.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,11 @@ function avx_loopset(instr, ops, arf, AM, LPSYM, LB, @nospecialize(vargs))
387387
nopsv = NOpsType[calcnops(ls, op) for op in ops]
388388
expandedv = [isexpanded(ls, ops, nopsv, i) for i ∈ eachindex(ops)]
389389
mrefs = create_mrefs!(ls, arf, arraysymbolinds, opsymbols, nopsv, expandedv, vargs)
390-
pushpreamble!(ls, Expr(:(=), ls.T, Expr(:call, :promote_type, [Expr(:call, :eltype, vptr(mref)) for mref ∈ mrefs]...)))
391-
# pushpreamble!(ls, Expr(:(=), ls.W, Expr(:call, lv(:pick_vector_width_val), [Expr(:call, :eltype, vptr(mref)) for mref ∈ mrefs]...)))
390+
append!(ls.includedactualarrays, (vptr(mref) for mref ∈ mrefs))
391+
# eltypes = [Expr(:call, :eltype, vptr(mref)) for mref ∈ mrefs]
392+
# pushpreamble!(ls, Expr(:(=), ls.T, Expr(:call, :promote_type, eltypes...)))
393+
# pushpreamble!(ls, Expr(:(=), ls.W, determine_width))
394+
# Expr(:call, lv(:pick_vector_width_val), [Expr(:call, :eltype, vptr(mref)) for mref ∈ mrefs]...)))
392395
num_params = num_arrays + num_parameters(AM)
393396
add_ops!(ls, instr, ops, mrefs, opsymbols, num_params, nopsv, expandedv, elementbytes)
394397
process_metadata!(ls, AM, length(arf))

src/split_loops.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function add_operation!(ls_new::LoopSet, included::Vector{Int}, ls::LoopSet, op:
1111
length(operations(ls_new)), name(op), op.elementbytes, instruction(op), op.node_type,
1212
loopdependencies(op), reduceddependencies(op), vparents, op.ref, reducedchildren(op)
1313
)
14+
accesses_memory(op) && addsetv!(ls_new.includedactualarrays, vptr(op.ref))
1415
push!(operations(ls_new), opnew)
1516
included[identifier(op)] = identifier(opnew)
1617
opnew

test/gemv.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,28 @@ using Test
178178
A = rand(R, M, K);
179179
x = rand(R, K);
180180
y1 = Vector{TC}(undef, M); y2 = similar(y1);
181-
mygemv!(y1, A, x)
182-
mygemvavx!(y2, A, x)
181+
mygemv!(y1, A, x);
182+
mygemvavx!(y2, A, x);
183183
@test y1 ≈ y2
184-
fill!(y2, -999.9); mygemv_avx!(y2, A, x)
184+
fill!(y2, -999.9); mygemv_avx!(y2, A, x);
185185
@test y1 ≈ y2
186186
fill!(y2, -999.9);
187187
mygemvavx_range!(y2, A, x)
188188
@test y1 ≈ y2
189189

190-
Abit = A .> 0.5
191-
fill!(y2, -999.9); mygemv_avx!(y2, Abit, x)
190+
Abit = A .> 0.5;
191+
fill!(y2, -999.9); mygemv_avx!(y2, Abit, x);
192192
@test y2 ≈ Abit * x
193-
xbit = x .> 0.5
194-
fill!(y2, -999.9); mygemv_avx!(y2, A, xbit)
193+
xbit = x .> 0.5;
194+
fill!(y2, -999.9); mygemv_avx!(y2, A, xbit);
195195
@test y2 ≈ A * xbit
196196

197197
B = rand(R, N, N);
198198
G1 = Matrix{TC}(undef, N, 1);
199199
G2 = similar(G1);
200200
# G3 = similar(G1);
201-
AtmulvB!(G1,B,1)
202-
AtmulvBavx1!(G2,B,1)
201+
AtmulvB!(G1,B,1);
202+
AtmulvBavx1!(G2,B,1);
203203
@test G1 ≈ G2
204204
fill!(G2, TC(NaN)); AtmulvBavx2!(G2,B,1);
205205

0 commit comments

Comments
 (0)