@@ -340,12 +340,12 @@ end
340
340
@noinline function add_masks (expr, masksym, reduction_symbols, default_module = :LoopVectorization )
341
341
# println("Called add masks!")
342
342
postwalk (expr) do x
343
- if @capture (x, M_. vstore! (ptr_, V_ ))
343
+ if @capture (x, M_. vstore! (args__ ))
344
344
M === nothing && (M = default_module)
345
- return :($ M. vstore! ($ ptr, $ V , $ masksym))
346
- elseif @capture (x, M_. vload (V_, ptr_ ))
345
+ return :($ M. vstore! ($ (args ... ) , $ masksym))
346
+ elseif @capture (x, M_. vload (args__ ))
347
347
M === nothing && (M = default_module)
348
- return :($ M. vload ($ V, $ ptr , $ masksym))
348
+ return :($ M. vload ($ (args ... ) , $ masksym))
349
349
# We mask the reductions, because the odds of them getting contaminated and therefore poisoning the results seem too great
350
350
# for reductions to be practical. If what we're vectorizing is simple enough not to worry about contamination...then
351
351
# it ought to be simple enough so we don't need @vectorize.
@@ -399,12 +399,12 @@ end
399
399
pA = indexed_expressions[A]
400
400
end
401
401
if i == declared_iter_sym
402
- return :($ mod. vstore! ($ pA + $ itersym , $ B ))
402
+ return :($ mod. vstore! ($ pA, $ B , $ itersym ))
403
403
elseif isa (i, Expr)
404
404
contains_itersym, i2 = subsymbol (i, declared_iter_sym, itersym)
405
- return :($ mod. vstore! ($ pA + $ i2 , $ B ))
405
+ return :($ mod. vstore! ($ pA, $ B , $ i2 ))
406
406
else
407
- return :($ mod. vstore! ($ pA + $ i , $ B ))
407
+ return :($ mod. vstore! ($ pA, $ B , $ i ))
408
408
end
409
409
elseif @capture (x, A_[i_,j_] = B_) || @capture (x, setindex! (A_, B_, i_, j_))
410
410
if A ∉ keys (indexed_expressions)
426
426
push! (loop_constants_quote. args, :( $ stridesym = $ stridexpr ))
427
427
loop_constants_dict[stridexpr] = stridesym
428
428
end
429
- return :($ mod. vstore! ($ pA + $ itersym + $ ej* $ stridesym, $ B ))
429
+ return :($ mod. vstore! ($ pA, $ B, $ itersym + $ ej* $ stridesym))
430
430
else
431
431
throw (" Indexing columns with vectorized loop variable is not supported." )
432
432
end
@@ -457,11 +457,11 @@ end
457
457
end
458
458
# # check to see if we are to do a vector load or a broadcast
459
459
if i == declared_iter_sym
460
- load_expr = :($ mod. vload ($ V, $ pA + $ itersym ))
460
+ load_expr = :($ mod. vload ($ V, $ pA, $ itersym ))
461
461
elseif isa (i, Expr)
462
462
contains_itersym, i2 = subsymbol (i, declared_iter_sym, itersym)
463
463
if contains_itersym
464
- load_expr = :($ mod. vload ($ V, $ pA + $ i2 ))
464
+ load_expr = :($ mod. vload ($ V, $ pA, $ i2 ))
465
465
else
466
466
load_expr = :($ mod. vbroadcast ($ V, $ pA - 1 + $ i))
467
467
end
497
497
push! (loop_constants_quote. args, :( $ stridesym = $ stridexpr ))
498
498
loop_constants_dict[stridexpr] = stridesym
499
499
end
500
- load_expr = :($ mod. vload ($ V, $ pA + $ itersym + $ ej* $ stridesym))
500
+ load_expr = :($ mod. vload ($ V, $ pA, $ itersym + $ ej* $ stridesym))
501
501
elseif j == declared_iter_sym
502
502
throw (" Indexing columns with vectorized loop variable is not supported." )
503
503
else
527
527
elseif @capture (x, A_[i_,:] .= B_)
528
528
# # Capture if there are multiple assignments...
529
529
if A ∉ keys (indexed_expressions)
530
- # pA = esc(gensym(A))
531
- # pA = esc(Symbol(:p,A))
532
530
pA = gensym (Symbol (:p ,A))
533
531
indexed_expressions[A] = pA
534
532
else
539
537
else
540
538
isym = i
541
539
end
542
-
543
540
br = gensym (:B )
544
541
br2 = gensym (:B )
545
542
coliter = gensym (:j )
546
-
547
543
stridexpr = :($ mod. LoopVectorization. stride_row ($ A))
548
544
if stridexpr ∈ keys (loop_constants_dict)
549
545
stridesym = loop_constants_dict[stridexpr]
@@ -552,32 +548,13 @@ end
552
548
push! (loop_constants_quote. args, :( $ stridesym = $ stridexpr ))
553
549
loop_constants_dict[stridexpr] = stridesym
554
550
end
555
- # numiterexpr = :(LoopVectorization.num_row_strides($A))
556
- # if numiterexpr ∈ keys(loop_constants_dict)
557
- # numitersym = loop_constants_dict[numiterexpr]
558
- # else
559
- # numitersym = gensym(:numiter)
560
- # push!(loop_constants_quote.args, :( $numitersym = $numiterexpr ))
561
- # loop_constants_dict[numiterexpr] = numitersym
562
- # end
563
-
564
551
expr = quote
565
552
$ br = $ mod. LoopVectorization. extract_data .($ B)
566
-
567
- # for $coliter ∈ 0:$numitersym-1
568
553
for $ coliter ∈ 0 : length ($ br)- 1
569
- @inbounds $ mod. vstore! ($ pA + $ isym + $ stridesym * $ coliter, getindex ( $ br, 1 + $ coliter) )
554
+ @inbounds $ mod. vstore! ($ pA, getindex ( $ br, 1 + $ coliter), $ isym + $ stridesym * $ coliter)
570
555
end
571
556
end
572
-
573
557
return expr
574
- # elseif @capture(x, @nexprs N_ ex_)
575
- # # println("Macroexpanding x:", x)
576
- # # @show ex
577
- # # mx = Expr(:escape, Expr(:block, Any[ Base.Cartesian.inlineanonymous(ex,i) for i = 1:N ]...))
578
- # mx = Expr(:block, Any[ Base.Cartesian.inlineanonymous(ex,i) for i = 1:N ]...)
579
- # # println("Macroexpanded x:", mx)
580
- # return mx
581
558
elseif @capture (x, zero (T_))
582
559
return :(zero ($ V))
583
560
elseif @capture (x, one (T_))
590
567
push! (main_body. args, :($ isymvec = $ mod. vadd ($ isymvec, vbroadcast ($ V, $ W)) ))
591
568
return isymvec
592
569
else
593
- # println("Returning x:", x)
594
570
return x
595
571
end
596
572
end , VectorizationDict, false , mod) # macro_escape = false
0 commit comments