Skip to content

Commit 38162eb

Browse files
committed
Made module checks for adding masks general.
1 parent 019c005 commit 38162eb

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

src/LoopVectorization.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -329,26 +329,33 @@ end
329329
end
330330

331331
@noinline function add_masks(expr, masksym, reduction_symbols)
332+
# println("Called add masks!")
332333
postwalk(expr) do x
333-
if @capture(x, LoopVectorization.SIMDPirates.vstore!(ptr_, V_))
334-
return :(LoopVectorization.SIMDPirates.vstore!($ptr, $V, $masksym))
335-
elseif @capture(x, LoopVectorization.SIMDPirates.vload(V_, ptr_))
336-
return :(LoopVectorization.SIMDPirates.vload($V, $ptr, $masksym))
334+
if @capture(x, M_.vstore!(ptr_, V_))
335+
return :($M.vstore!($ptr, $V, $masksym))
336+
elseif @capture(x, M_.vload(V_, ptr_))
337+
return :($M.vload($V, $ptr, $masksym))
337338
# We mask the reductions, because the odds of them getting contaminated and therefore poisoning the results seems too great
338339
# for reductions to be practical. If what we're vectorizing is simple enough not to worry about contamination...then
339340
# it ought to be simple enough so we don't need @vectorize.
340-
elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vadd(reductionA_, B_ ) ) || @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vadd(B_, reductionA_ ) ) || @capture(x, reductionA_ = vadd(reductionA_, B_ ) ) || @capture(x, reductionA_ = vadd(B_, reductionA_ ) )
341-
return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vadd($reductionA, $B), $reductionA) )
342-
elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vmul(reductionA_, B_ ) ) || @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vmul(B_, reductionA_ ) ) || @capture(x, reductionA_ = vmul(reductionA_, B_ ) ) || @capture(x, reductionA_ = vmul(B_, reductionA_ ) )
343-
return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vmul($reductionA, $B), $reductionA) )
344-
elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vmuladd(B_, C_, reductionA_) ) || @capture(x, reductionA_ = vmuladd(B_, C_, reductionA_) )
345-
return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vmuladd($B, $C, $reductionA), $reductionA) )
346-
elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vfnmadd(B_, C_, reductionA_ ) ) || @capture(x, reductionA_ = vfnmadd(B_, C_, reductionA_ ) )
347-
return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vfnmadd($B, $C, $reductionA), $reductionA) )
348-
elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vsub(reductionA_, B_ ) ) || @capture(x, reductionA_ = vsub(reductionA_, B_ ) )
349-
return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vsub($reductionA, $B), $reductionA) )
350-
# elseif @capture(x, reductionA_ = LoopVectorization.SIMDPirates.vmul(reductionA_, B_ ) )
351-
# return :( $reductionA = LoopVectorization.SIMDPirates.vifelse($masksym, LoopVectorization.SIMDPirates.vmul($reductionA, $B), $reductionA) )
341+
elseif @capture(x, reductionA_ = M_.vadd(reductionA_, B_ ) ) || @capture(x, reductionA_ = M_.vadd(B_, reductionA_ ) ) || @capture(x, reductionA_ = vadd(reductionA_, B_ ) ) || @capture(x, reductionA_ = vadd(B_, reductionA_ ) )
342+
M === nothing && (M = :SIMDPirates)
343+
return :( $reductionA = $M.vifelse($masksym, $M.vadd($reductionA, $B), $reductionA) )
344+
elseif @capture(x, reductionA_ = M_.vmul(reductionA_, B_ ) ) || @capture(x, reductionA_ = M_.vmul(B_, reductionA_ ) ) || @capture(x, reductionA_ = vmul(reductionA_, B_ ) ) || @capture(x, reductionA_ = vmul(B_, reductionA_ ) )
345+
M === nothing && (M = :SIMDPirates)
346+
return :( $reductionA = $M.vifelse($masksym, $M.vmul($reductionA, $B), $reductionA) )
347+
elseif @capture(x, reductionA_ = M_.vmuladd(B_, C_, reductionA_) ) || @capture(x, reductionA_ = vmuladd(B_, C_, reductionA_) )
348+
M === nothing && (M = :SIMDPirates)
349+
return :( $reductionA = $M.vifelse($masksym, $M.vmuladd($B, $C, $reductionA), $reductionA) )
350+
elseif @capture(x, reductionA_ = M_.vfnmadd(B_, C_, reductionA_ ) ) || @capture(x, reductionA_ = vfnmadd(B_, C_, reductionA_ ) )
351+
M === nothing && (M = :SIMDPirates)
352+
return :( $reductionA = $M.vifelse($masksym, $M.vfnmadd($B, $C, $reductionA), $reductionA) )
353+
elseif @capture(x, reductionA_ = M_.vsub(reductionA_, B_ ) ) || @capture(x, reductionA_ = vsub(reductionA_, B_ ) )
354+
M === nothing && (M = :SIMDPirates)
355+
return :( $reductionA = $M.vifelse($masksym, $M.vsub($reductionA, $B), $reductionA) )
356+
# elseif @capture(x, reductionA_ = M_.vmul(reductionA_, B_ ) )
357+
# M === nothing && (M = :SIMDPirates)
358+
# return :( $reductionA = $M.vifelse($masksym, $M.vmul($reductionA, $B), $reductionA) )
352359
else
353360
return x
354361
end

0 commit comments

Comments
 (0)