Skip to content

Commit c7e617a

Browse files
committed
test += in more places, handle vectorized -> non-vectorized
1 parent c2a8949 commit c7e617a

File tree

5 files changed

+42
-22
lines changed

5 files changed

+42
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.9"
4+
version = "0.12.10"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_compute.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,22 @@ function getu₁forreduct(ls::LoopSet, op::Operation, u₁::Int)
333333
return getu₁full(ls, u₁)
334334
end
335335
end
336-
336+
isidentityop(op::Operation) = iscompute(op) && (instruction(op).instr === :identity) && (length(parents(op)) == 1)
337+
function reduce_parent!(q::Expr, ls::LoopSet, op::Operation, opp::Operation, parent::Symbol)
338+
isvectorized(op) && return parent
339+
dependent_outer_reducts(ls, op) && return parent
340+
if isvectorized(opp)
341+
oppt = opp
342+
elseif isidentityop(opp)
343+
oppt = only(parents(opp))
344+
isvectorized(oppt) || return parent
345+
else
346+
return parent
347+
end
348+
newp = gensym(parent)
349+
push!(q.args, Expr(:(=), newp, Expr(:call, lv(reduction_to_scalar(oppt.instruction)), parent)))
350+
newp
351+
end
337352
function lower_compute!(
338353
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, mask::Bool
339354
)
@@ -468,25 +483,26 @@ function lower_compute!(
468483
add_loopvalue!(instrcall, loopval, ua, u₁)
469484
elseif name(opp) === name(op)
470485
selfdep = n
471-
if ((isvectorized(first(parents_op)) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)) ||
486+
if ((isvectorized(opp) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)) ||
472487
(parents_u₁syms[n] != u₁unrolledsym) || (parents_u₂syms[n] != u₂unrolledsym)
473488

474489
selfopname, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
475490
# if (uₚ ≠ 0) & (uₚ ≠ u₁)
476491
# dopartialmap = true
477492
# end
493+
# @show selfopname, instr
478494
push!(instrcall.args, selfopname)
479495
else
480496
push!(instrcall.args, varsym)
481497
end
482498
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (isouterreduction(ls, opp) != -1)
483499
# this checks if the parent is u₂ unrolled but this operation is not, in which case we need to reduce it.
484-
push!(instrcall.args, reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls)))
500+
reduced_u₂ = reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls))
501+
reduced_u₂ = reduce_parent!(q, ls, op, opp, reduced_u₂)
502+
push!(instrcall.args, reduced_u₂)
485503
else
486504
parent, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
487-
# if name(op) === Symbol("##op#9536")
488-
# @show parent
489-
# end
505+
parent = reduce_parent!(q, ls, op, opp, parent)
490506
if (selfdep == 0) && search_tree(parents(opp), name(op))
491507
selfdep = n
492508
push!(instrcall.args, parent)
@@ -495,8 +511,6 @@ function lower_compute!(
495511
else
496512
push!(instrcall.args, parent)
497513
end
498-
499-
# @show uₚ, u₁, op
500514
end
501515
end
502516
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ Execute an `@avx` block. The block's code is represented via the arguments:
623623
@generated function _avx_!(
624624
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, var"#lv#tuple#args#"::Tuple{LB,V}
625625
) where {UNROLL, OPS, ARF, AM, LPSYM, LB, V}
626-
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
626+
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
627627
ls = _avx_loopset(OPS, ARF, AM, LPSYM, LB.parameters, V.parameters, UNROLL)
628628
# return @show avx_body(ls, UNROLL)
629629
if last(UNROLL) > 1

test/gemv.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using Test
3131
for j eachindex(x)
3232
yᵢ += A[i,j] * x[j]
3333
end
34-
y[i] = yᵢ
34+
y[i] += yᵢ
3535
end
3636
end
3737
function mygemvavx_range!(y, A, x)
@@ -246,8 +246,12 @@ using Test
246246
# y1 = Vector{TC}(undef, M); y2 = similar(y1);
247247

248248
mygemv!(y1, A, x);
249-
mygemvavx!(y2, A, x);
249+
fill!(y2, 0); mygemvavx!(y2, A, x);
250250
@test y1full y2full
251+
mygemvavx!(y2, A, x);
252+
@test y1 .* 2 y2
253+
mygemvavx!(y2, A, x);
254+
@test y1 .* 3 y2
251255
fill!(y2, -9999); mygemv_avx!(y2, A, x);
252256
@test y1full y2full
253257
fill!(y2, -9999);
@@ -263,12 +267,12 @@ using Test
263267
fill!(y2, -9999); mygemv_avx!(y2, Abit, x);
264268
@test y2 Abit * x
265269
end
266-
fill!(y2, -9999); mygemvavx!(y2, Abit, x);
270+
fill!(y2, 0); mygemvavx!(y2, Abit, x);
267271
@test y2 Abit * x
268272
xbit = x .> 0.5;
269273
fill!(y2, -9999); mygemv_avx!(y2, A, xbit);
270274
@test y2 A * xbit
271-
fill!(y2, -9999); mygemvavx!(y2, A, xbit);
275+
fill!(y2, 0); mygemvavx!(y2, A, xbit);
272276
@test y2 A * xbit
273277
end
274278

test/shuffleloadstores.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,17 @@ end
209209
@test Cc1 Cc2# ≈ Cc3
210210
end
211211
end
212-
M = 100
213-
G = 50
214-
J = 50
215-
H = 300
212+
if VERSION v"1.6.0-rc1"
213+
M = 100
214+
G = 50
215+
J = 50
216+
H = 300
216217

217-
A = Matrix(Tridiagonal(rand(G-1,G-1)));
218-
B = rand(Complex{Float64}, 2*J+1, G-1, H+1, M+1);
219-
ϕ = rand(Complex{Float64}, 2*J+1, G+1, H+1, M+1);
220-
@test issue209(M, G, J, H, A, B, ϕ) issue209_noavx(M, G, J, H, A, B, ϕ)
218+
A = Matrix(Tridiagonal(rand(G-1,G-1)));
219+
B = rand(Complex{Float64}, 2*J+1, G-1, H+1, M+1);
220+
ϕ = rand(Complex{Float64}, 2*J+1, G+1, H+1, M+1);
221+
@test issue209(M, G, J, H, A, B, ϕ) issue209_noavx(M, G, J, H, A, B, ϕ)
222+
end
221223
end
222224

223225

0 commit comments

Comments
 (0)