Skip to content

Commit e9a848d

Browse files
committed
Added a few tests.
1 parent f5bd59a commit e9a848d

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

src/add_compute.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ function search_tree(opv::Vector{Operation}, var::Symbol) # relies on cycles bei
7979
end
8080
false
8181
end
82+
function search_tree(opv::Vector{Operation}, var::Operation) # relies on cycles being forbidden
83+
for opp ∈ opv
84+
opp === var && return true
85+
search_tree(parents(opp), var) && return true
86+
end
87+
false
88+
end
8289
function update_reduction_status!(parentvec::Vector{Operation}, deps::Vector{Symbol}, parent::Symbol)
8390
for opp ∈ parentvec
8491
if name(opp) === parent

src/add_ifelse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
88
# for now, just simple 1-liners
99
@assert length(RHS.args) == 3 "if statements without an else cannot be assigned to a variable."
1010
condition = first(RHS.args)
11-
condop = if mpref === nothing
11+
condop = if isnothing(mpref)
1212
add_operation!(ls, gensym(:mask), condition, elementbytes, position)
1313
else
1414
add_operation!(ls, gensym(:mask), condition, mpref, elementbytes, position)
1515
end
1616
iftrue = RHS.args[2]
1717
if iftrue isa Expr
1818
trueop = add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
19-
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop))
19+
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop)) && !search_tree(parents(condop), trueop)
2020
trueop.instruction = Instruction(:conditionalload)
2121
push!(parents(trueop), condop)
2222
end
@@ -26,7 +26,7 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
2626
iffalse = RHS.args[3]
2727
if iffalse isa Expr
2828
falseop = add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
29-
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop))
29+
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop)) && !search_tree(parents(condop), falseop)
3030
falseop.instruction = Instruction(:conditionalload)
3131
push!(parents(falseop), negateop!(ls, condop, elementbytes))
3232
end

test/ifelsemasks.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,36 @@ T = Float32
181181
end
182182
end
183183

184+
function notacondload!(C, A, b)
185+
@inbounds for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
186+
C[m,n] = A[m,n] * (b[n] > 0 ? b[n] : -b[n])
187+
end
188+
end
189+
function notacondloadavx!(C, A, b)
190+
@avx for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
191+
C[m,n] = A[m,n] * (b[n] > 0 ? b[n] : -b[n])
192+
end
193+
end
194+
function condloadscalar!(C, A, c, b)
195+
@inbounds for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
196+
C[m,n] = A[m,n] * (c[n] > 0 ? b[n] : 1) + c[n]
197+
end
198+
end
199+
function condloadscalaravx!(C, A, c, b)
200+
@avx for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
201+
C[m,n] = A[m,n] * (c[n] > 0 ? b[n] : 1) + c[n]
202+
end
203+
end
204+
function maskedloadscalar!(C, A, b)
205+
@inbounds for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
206+
C[m,n] = A[m,n] * (A[m,n] > 0 ? b[n] : 1)
207+
end
208+
end
209+
function maskedloadscalaravx!(C, A, b)
210+
@avx for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
211+
C[m,n] = A[m,n] * (A[m,n] > 0 ? b[n] : 1)
212+
end
213+
end
184214
function AtmulBpos!(C, A, B)
185215
@inbounds for n ∈ 1:size(C,2), m ∈ 1:size(C,1)
186216
Cₘₙ = zero(eltype(C))
@@ -368,17 +398,32 @@ T = Float32
368398
A = rand(T(-100):T(100), K, M);
369399
B = rand(T(-100):T(100), K, N);
370400
C1 = rand(T(-100):T(100), M, N);
401+
b = rand(T(-100):T(100), N);
402+
d = rand(T(-100):T(100), N);
371403
else
372404
A = randn(T, K, M);
373405
B = randn(T, K, N);
374406
C1 = randn(T, M, N);
407+
b = randn(T, N);
408+
d = randn(T, N);
375409
end;
376410
C2 = copy(C1); C3 = copy(C1);
377411
AtmulBpos!(C1, A, B)
378412
AtmulBposavx!(C2, A, B)
379413
AtmulBpos_avx!(C3, A, B)
380414
@test C1 ≈ C2
381415
@test C1 ≈ C3
416+
C1 = similar(B);
417+
C2 = similar(B);
418+
notacondload!(C1, B, b)
419+
notacondloadavx!(C2, B, b)
420+
@test C1 ≈ C2
421+
maskedloadscalar!(C1, B, b)
422+
maskedloadscalaravx!(C2, B, b)
423+
@test C1 ≈ C2
424+
condloadscalar!(C1, B, b, d)
425+
condloadscalaravx!(C2, B, b, d)
426+
@test C1 ≈ C2
382427
end
383428

384429

0 commit comments

Comments
 (0)