Skip to content

Commit 9ca5bdd

Browse files
committed
use indexed vload/vstores of vectorizables rather than incrementing them, because indexing is an easier API to support than incrementing the vectorizable object.
1 parent 44f19c0 commit 9ca5bdd

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

src/LoopVectorization.jl

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,12 @@ end
340340
@noinline function add_masks(expr, masksym, reduction_symbols, default_module = :LoopVectorization)
341341
# println("Called add masks!")
342342
postwalk(expr) do x
343-
if @capture(x, M_.vstore!(ptr_, V_))
343+
if @capture(x, M_.vstore!(args__))
344344
M === nothing && (M = default_module)
345-
return :($M.vstore!($ptr, $V, $masksym))
346-
elseif @capture(x, M_.vload(V_, ptr_))
345+
return :($M.vstore!($(args...), $masksym))
346+
elseif @capture(x, M_.vload(args__))
347347
M === nothing && (M = default_module)
348-
return :($M.vload($V, $ptr, $masksym))
348+
return :($M.vload($(args...), $masksym))
349349
# We mask the reductions, because the odds of them getting contaminated and therefore poisoning the results seems too great
350350
# for reductions to be practical. If what we're vectorizing is simple enough not to worry about contamination...then
351351
# it ought to be simple enough so we don't need @vectorize.
@@ -399,12 +399,12 @@ end
399399
pA = indexed_expressions[A]
400400
end
401401
if i == declared_iter_sym
402-
return :($mod.vstore!($pA + $itersym, $B))
402+
return :($mod.vstore!($pA, $B, $itersym))
403403
elseif isa(i, Expr)
404404
contains_itersym, i2 = subsymbol(i, declared_iter_sym, itersym)
405-
return :($mod.vstore!($pA + $i2, $B))
405+
return :($mod.vstore!($pA, $B, $i2))
406406
else
407-
return :($mod.vstore!($pA + $i, $B))
407+
return :($mod.vstore!($pA, $B, $i))
408408
end
409409
elseif @capture(x, A_[i_,j_] = B_) || @capture(x, setindex!(A_, B_, i_, j_))
410410
if A keys(indexed_expressions)
@@ -426,7 +426,7 @@ end
426426
push!(loop_constants_quote.args, :( $stridesym = $stridexpr ))
427427
loop_constants_dict[stridexpr] = stridesym
428428
end
429-
return :($mod.vstore!($pA + $itersym + $ej*$stridesym, $B))
429+
return :($mod.vstore!($pA, $B, $itersym + $ej*$stridesym))
430430
else
431431
throw("Indexing columns with vectorized loop variable is not supported.")
432432
end
@@ -457,11 +457,11 @@ end
457457
end
458458
## check to see if we are to do a vector load or a broadcast
459459
if i == declared_iter_sym
460-
load_expr = :($mod.vload($V, $pA + $itersym ))
460+
load_expr = :($mod.vload($V, $pA, $itersym ))
461461
elseif isa(i, Expr)
462462
contains_itersym, i2 = subsymbol(i, declared_iter_sym, itersym)
463463
if contains_itersym
464-
load_expr = :($mod.vload($V, $pA + $i2 ))
464+
load_expr = :($mod.vload($V, $pA, $i2 ))
465465
else
466466
load_expr = :($mod.vbroadcast($V, $pA - 1 + $i))
467467
end
@@ -497,7 +497,7 @@ end
497497
push!(loop_constants_quote.args, :( $stridesym = $stridexpr ))
498498
loop_constants_dict[stridexpr] = stridesym
499499
end
500-
load_expr = :($mod.vload($V, $pA + $itersym + $ej*$stridesym))
500+
load_expr = :($mod.vload($V, $pA, $itersym + $ej*$stridesym))
501501
elseif j == declared_iter_sym
502502
throw("Indexing columns with vectorized loop variable is not supported.")
503503
else
@@ -527,8 +527,6 @@ end
527527
elseif @capture(x, A_[i_,:] .= B_)
528528
## Capture if there are multiple assignments...
529529
if A keys(indexed_expressions)
530-
# pA = esc(gensym(A))
531-
# pA = esc(Symbol(:p,A))
532530
pA = gensym(Symbol(:p,A))
533531
indexed_expressions[A] = pA
534532
else
@@ -539,11 +537,9 @@ end
539537
else
540538
isym = i
541539
end
542-
543540
br = gensym(:B)
544541
br2 = gensym(:B)
545542
coliter = gensym(:j)
546-
547543
stridexpr = :($mod.LoopVectorization.stride_row($A))
548544
if stridexpr keys(loop_constants_dict)
549545
stridesym = loop_constants_dict[stridexpr]
@@ -552,32 +548,13 @@ end
552548
push!(loop_constants_quote.args, :( $stridesym = $stridexpr ))
553549
loop_constants_dict[stridexpr] = stridesym
554550
end
555-
# numiterexpr = :(LoopVectorization.num_row_strides($A))
556-
# if numiterexpr ∈ keys(loop_constants_dict)
557-
# numitersym = loop_constants_dict[numiterexpr]
558-
# else
559-
# numitersym = gensym(:numiter)
560-
# push!(loop_constants_quote.args, :( $numitersym = $numiterexpr ))
561-
# loop_constants_dict[numiterexpr] = numitersym
562-
# end
563-
564551
expr = quote
565552
$br = $mod.LoopVectorization.extract_data.($B)
566-
567-
# for $coliter ∈ 0:$numitersym-1
568553
for $coliter 0:length($br)-1
569-
@inbounds $mod.vstore!($pA + $isym + $stridesym * $coliter, getindex($br,1+$coliter))
554+
@inbounds $mod.vstore!($pA, getindex($br,1+$coliter), $isym + $stridesym * $coliter)
570555
end
571556
end
572-
573557
return expr
574-
# elseif @capture(x, @nexprs N_ ex_)
575-
# # println("Macroexpanding x:", x)
576-
# # @show ex
577-
# # mx = Expr(:escape, Expr(:block, Any[ Base.Cartesian.inlineanonymous(ex,i) for i = 1:N ]...))
578-
# mx = Expr(:block, Any[ Base.Cartesian.inlineanonymous(ex,i) for i = 1:N ]...)
579-
# # println("Macroexpanded x:", mx)
580-
# return mx
581558
elseif @capture(x, zero(T_))
582559
return :(zero($V))
583560
elseif @capture(x, one(T_))
@@ -590,7 +567,6 @@ end
590567
push!(main_body.args, :($isymvec = $mod.vadd($isymvec, vbroadcast($V, $W)) ))
591568
return isymvec
592569
else
593-
# println("Returning x:", x)
594570
return x
595571
end
596572
end, VectorizationDict, false, mod) # macro_escape = false

0 commit comments

Comments
 (0)