Skip to content

Commit 349000c

Browse files
committed
Support chained comparisons, fixes #313 (tested), fix key error, which should fix #312 (untested)
1 parent b09316b commit 349000c

File tree

5 files changed

+93
-18
lines changed

5 files changed

+93
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.60"
4+
version = "0.12.61"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/modeling/graphs.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1050,9 +1050,25 @@ strip_op_linenumber_nodes(q::Expr) = only(filter(x -> !isa(x, LineNumberNode), q
10501050
function add_operation!(ls::LoopSet, LHS::Symbol, RHS::Symbol, elementbytes::Int, position::Int)
10511051
add_constant!(ls, RHS, ls.loopsymbols[1:position], LHS, elementbytes)
10521052
end
1053+
function add_comparison!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int)
1054+
Nargs = length(RHS.args)
1055+
@assert (Nargs 5) & isodd(Nargs)
1056+
p1 = add_assignment!(ls, gensym!(ls, "leftcmp"), RHS.args[1], elementbytes, position)::Operation
1057+
p2 = add_assignment!(ls, gensym!(ls, "middlecmp"), RHS.args[3], elementbytes, position)::Operation
1058+
cmpname = Nargs == 3 ? LHS : gensym!(ls, "cmp")
1059+
cmp = add_compute!(ls, cmpname, RHS.args[2], Operation[p1, p2], elementbytes)::Operation
1060+
for i 5:2:Nargs
1061+
pnew = add_assignment!(ls, gensym!(ls, "rightcmp"), RHS.args[i], elementbytes, position)::Operation
1062+
cmpchain = add_compute!(ls, gensym!(ls, "cmpchain"), RHS.args[i-1], Operation[p2, pnew], elementbytes)::Operation
1063+
cmpname = Nargs == i ? LHS : gensym!(ls, "cmp")
1064+
cmp = add_compute!(ls, cmpname, :&, [cmp, cmpchain], elementbytes)::Operation
1065+
p2 = pnew
1066+
end
1067+
return cmp
1068+
end
10531069
function add_operation!(
10541070
ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int
1055-
)
1071+
)
10561072
if RHS.head === :ref
10571073
add_load_ref!(ls, LHS, RHS, elementbytes)
10581074
elseif RHS.head === :call
@@ -1083,6 +1099,8 @@ function add_operation!(
10831099
# op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS, elementbytes, :numericconstant)
10841100
# pushpreamble!(ls, op, c)
10851101
# op
1102+
elseif Meta.isexpr(RHS, :comparison)
1103+
add_comparison!(ls, LHS, RHS, elementbytes, position)
10861104
else
10871105
throw(LoopError("Expression not recognized.", RHS))
10881106
end
@@ -1125,6 +1143,8 @@ function add_operation!(
11251143
# op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS_sym, elementbytes, :numericconstant)
11261144
# pushpreamble!(ls, op, c)
11271145
# op
1146+
elseif Meta.isexpr(RHS, :comparison, 5)
1147+
add_comparison!(ls, LHS, RHS, elementbytes, position)
11281148
else
11291149
throw(LoopError("Expression not recognized.", RHS))
11301150
end
@@ -1141,6 +1161,7 @@ function prepare_rhs_for_storage!(ls::LoopSet, RHS::Union{Symbol,Expr}, array, r
11411161
mpref.parents = cachedparents
11421162
op = add_store!(ls, mpref, elementbytes)
11431163
ls.syms_aliasing_refs[findfirst(==(mpref.mref), ls.refs_aliasing_syms)] = lrhs
1164+
ls.opdict[lrhs] = op
11441165
return op
11451166
end
11461167

src/parse/add_loads.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true)
3232
end
3333
end
3434
end
35-
if allmatch
36-
return isstore(opp) ? getop(ls, first(parents(opp))) : opp
37-
end
35+
allmatch && isstore(opp) ? first(parents(opp)) : opp
3836
end
3937
add_vptr!(ls, op.ref.ref.array, vptr(op), actualarray)
4038
pushop!(ls, op, name(op))

src/parse/add_stores.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
function add_unique_store!(ls::LoopSet, op::Operation)
2-
add_vptr!(ls, op)
3-
pushop!(ls, op, name(op.ref))
2+
add_vptr!(ls, op)
3+
pushop!(ls, op, name(op.ref))
44
end
55
function cse_store!(ls::LoopSet, op::Operation)
6-
id = identifier(op)
7-
ls.operations[id] = op
8-
ls.opdict[op.variable] = op
6+
id = identifier(op)
7+
ls.operations[id] = op
8+
ls.opdict[op.variable] = op
99
end
1010
function add_store!(ls::LoopSet, op::Operation, add_pvar::Bool = !any(r -> r == op.ref, ls.refs_aliasing_syms))
11-
@assert isstore(op)
12-
if add_pvar
13-
push!(ls.syms_aliasing_refs, name(first(parents(op))))
14-
push!(ls.refs_aliasing_syms, op.ref)
15-
end
16-
id = op.identifier
17-
id == length(operations(ls)) ? add_unique_store!(ls, op) : cse_store!(ls, op)
11+
@assert isstore(op)
12+
if add_pvar
13+
push!(ls.syms_aliasing_refs, name(first(parents(op))))
14+
push!(ls.refs_aliasing_syms, op.ref)
15+
end
16+
id = op.identifier
17+
id == length(operations(ls)) ? add_unique_store!(ls, op) : cse_store!(ls, op)
1818
end
1919
function add_copystore!(
2020
ls::LoopSet, parent::Operation, mpref::ArrayReferenceMetaPosition, elementbytes::Int
@@ -89,7 +89,7 @@ function add_conditional_store!(ls::LoopSet, LHS, condop::Operation, storeop::Op
8989
id = length(ls.operations)
9090
@assert pvar ls.syms_aliasing_refs
9191
# if pvar ∉ ls.syms_aliasing_refs
92-
# FIXME properly handle CSE of conditional stores.
92+
# FIXME properly handle CSE of conditional stores.
9393
push!(ls.syms_aliasing_refs, pvar)
9494
push!(ls.refs_aliasing_syms, mref)
9595
storeparents = [storeop, condop]

test/ifelsemasks.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,46 @@ T = Float32
428428
return w
429429
end
430430

431+
432+
function mwe_v(dest, src, lut)
433+
# for i in eachindex(src)
434+
LoopVectorization.@turbo for i in eachindex(src)
435+
v = src[i]
436+
s = 128
437+
s2 = s | 64
438+
s = ifelse(lut[s2] <= v, s2, s)
439+
s2 = s | 32
440+
s = ifelse(lut[s2] <= v, s2, s)
441+
dest[i] = s
442+
end
443+
return dest
444+
end
445+
446+
function mwe_s(dest, src, lut)
447+
for i in eachindex(src)
448+
# LoopVectorization.@turbo for i in eachindex(src)
449+
v = src[i]
450+
s = 128
451+
s2 = s | 64
452+
s = ifelse(lut[s2] <= v, s2, s)
453+
s2 = s | 32
454+
s = ifelse(lut[s2] <= v, s2, s)
455+
dest[i] = s
456+
end
457+
return dest
458+
end
459+
460+
function turbocomparison!(m)
461+
@turbo for i in eachindex(m)
462+
m[i] = ifelse(0<m[i]<0.5, 0.0, m[i])
463+
end
464+
end
465+
function turbocomparison!(m, y)
466+
@turbo for i in eachindex(m)
467+
m[i] = ifelse(0<m[i]<y[i]<0.5, 0.0, m[i])
468+
end
469+
end
470+
431471
N = 117
432472
for T (Float32, Float64, Int32, Int64)
433473
@show T, @__LINE__
@@ -612,5 +652,21 @@ T = Float32
612652
@test barycentric_weight3(X) bX
613653
@test barycentric_weight4(X) bX
614654
end
655+
656+
let
657+
lut = let x = cumsum(rand(Float32, 256)./128); x[end] = Inf; x end;
658+
src = rand(Float32, N);
659+
660+
@test mwe_v(Vector{Int}(undef, N), src, lut) == mwe_s(Vector{Int}(undef, N), src, lut)
661+
@test mwe_v(Vector{Int32}(undef, N), src, lut) == mwe_s(Vector{Int32}(undef, N), src, lut)
662+
end
663+
664+
let m = rand(25, 25), y = rand(25,25), baseline5 = (@. ifelse(0 < m < 0.5, 0.0, m)), baseline7 = @. ifelse(0 < y < m < 0.5, 0.0, y)
665+
turbocomparison!(y, m)
666+
@test y == baseline7
667+
turbocomparison!(m)
668+
@test m == baseline5
669+
end
670+
615671
end
616672

0 commit comments

Comments
 (0)