@@ -367,45 +367,57 @@ end
367
367
368
368
369
369
# Return `true` when `op`'s instruction is one of the conditional memory operations,
# i.e. its `instr` field is `:conditionalload` or `:conditionalstore!`.
function isconditionalmemop(op::Operation)
    instr = instruction(op).instr
    return (instr === :conditionalload) || (instr === :conditionalstore!)
end
370
"""
    add_memory_mask!(memopexpr::Expr, op::Operation, td::UnrollArgs, mask::Bool, ls::LoopSet, u₁ᵢ::Int)

Append the mask argument (if one is needed) to `memopexpr.args` for the memory operation `op`.

- If `op` is a conditional load/store (see `isconditionalmemop`), the condition is taken from
  the last parent of `op` and, when `mask` is set and `op` is vectorized, combined with
  `MASKSYMBOL`.
- Otherwise, `MASKSYMBOL` alone is appended when `mask` is set and `op` is vectorized.
- In all other cases `memopexpr` is left untouched.

# Arguments
- `memopexpr`: call expression being built; mutated in place.
- `op`: the memory operation being lowered.
- `td`: unroll description; `u₁`, `u₁loopsym`, `u₂loopsym`, `vloopsym` and `suffix` are read.
- `mask`: whether the loop mask `MASKSYMBOL` has to be applied.
- `ls`: forwarded to `condvarname_and_unroll`.
- `u₁ᵢ`: `0` selects the code paths that treat the condition as a whole; any other value
  extracts a single element with `getfield(getfield(condvar, 1), u₁ᵢ, false)`.
  Presumably the 1-based index of the `u₁` unroll being lowered -- TODO confirm against callers.

Returns `nothing`.
"""
function add_memory_mask!(
    memopexpr::Expr, op::Operation, td::UnrollArgs, mask::Bool, ls::LoopSet, u₁ᵢ::Int
)
    @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, suffix = td
    if isconditionalmemop(op)
        condop = last(parents(op))
        opu₂ = (suffix ≠ -1) && isu₂unrolled(op)
        condvar, condu₁unrolled = condvarname_and_unroll(
            condop, u₁loopsym, u₂loopsym, vloopsym, suffix, opu₂, ls
        )
        # The condition variable's name carries `u₁` when it is unrolled along u₁, else `1`.
        u = condu₁unrolled ? u₁ : 1
        condvar = Symbol(condvar, '_', u)
        # If we need to apply `MASKSYMBOL` and the condvar:
        # 2 condvar possibilities:
        #   `VecUnroll` applied everywhere
        #   single mask "broadcast"
        # 2 mask possibilities:
        #   u₁loopsym ≠ vloopsym, and we mask all
        #   u₁loopsym == vloopsym, and we mask last
        # broadcast both, so can do so implicitly
        # this is true whether or not `condbroadcast`
        if !mask || !isvectorized(op)
            # Parenthesized: `|` binds tighter than `==`, so the unparenthesized form
            # would compare `u₁ᵢ` against `0 | (u == 1)`.
            if (u₁ᵢ == 0) || (u == 1)
                push!(memopexpr.args, condvar)
            else
                push!(memopexpr.args, :($getfield($getfield($condvar, 1), $u₁ᵢ, false)))
            end
        elseif (u₁loopsym ≢ vloopsym) | (u₁ == 1) # mask all equivalently
            push!(memopexpr.args, Expr(:call, lv(:&), condvar, MASKSYMBOL))
        # if the condition `(u₁loopsym ≢ vloopsym) | (u₁ == 1)` failed,
        # we need to apply `MASKSYMBOL` only to the last unroll.
        elseif !condu₁unrolled && isu₁unrolled(op) # condbroadcast
            if u₁ᵢ == 0
                # explicitly broadcast `condvar`, and apply `MASKSYMBOL` to the end
                t = Expr(:call, lv(:promote))
                for _ in 1:(u₁ - 1)
                    push!(t.args, condvar)
                end
                push!(t.args, Expr(:call, lv(:&), condvar, MASKSYMBOL))
                push!(memopexpr.args, Expr(:call, lv(:VecUnroll), t))
            else
                # NOTE(review): `MASKSYMBOL` is not applied here even when `u₁ᵢ == u₁`,
                # unlike the non-broadcast branch below -- confirm this is intended.
                push!(memopexpr.args, condvar)
            end
        elseif u₁ᵢ == 0 # !condbroadcast && !vecunrolled
            push!(memopexpr.args, Expr(:call, lv(:and_last), condvar, MASKSYMBOL))
        elseif u₁ᵢ == u₁ # last unroll: combine its condition with the loop mask
            condᵢ = :($getfield($getfield($condvar, 1), $u₁ᵢ, false))
            push!(memopexpr.args, Expr(:call, lv(:&), condᵢ, MASKSYMBOL))
        else
            # Not the last unroll: the extracted condition is used as-is. (No single-argument
            # `&` wrapper; this matches the unmasked branch above.)
            push!(memopexpr.args, :($getfield($getfield($condvar, 1), $u₁ᵢ, false)))
        end
    elseif mask && isvectorized(op)
        push!(memopexpr.args, MASKSYMBOL)
    end
    return nothing
end
410
422
411
423
# varassignname(var::Symbol, u::Int, isunrolled::Bool) = isunrolled ? Symbol(var, u) : var
0 commit comments