Skip to content

Commit 48226d4

Browse files
committed
Cache rejectinterleave value, and update special function unroll checks
1 parent 0ae1af9 commit 48226d4

13 files changed

+214
-69
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ DocStringExtensions = "0.8"
2424
IfElse = "0.1"
2525
OffsetArrays = "1.4.1, 1.5"
2626
Requires = "1"
27-
SLEEFPirates = "0.6.11"
27+
SLEEFPirates = "0.6.12"
2828
Static = "0.2"
2929
ThreadingUtilities = "0.3"
3030
UnPack = "1"
31-
VectorizationBase = "0.19.3"
31+
VectorizationBase = "0.19.5"
3232
julia = "1.5"
3333

3434
[extras]

benchmark/looptests.jl

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,53 @@ function jgemm!(𝐂, 𝐀ᵀ::Adjoint, 𝐁ᵀ::Adjoint)
6464
end
6565
end
6666
function gemmavx!(𝐂, 𝐀, 𝐁)
67-
@avx for m ∈ axes(𝐀,1), n ∈ axes(𝐁,2)
67+
@avx for m ∈ indices((𝐀,𝐂),1), n ∈ indices((𝐁,𝐂),2)
68+
𝐂ₘₙ = zero(eltype(𝐂))
69+
for k ∈ indices((𝐀,𝐁),(2,1))
70+
𝐂ₘₙ += 𝐀[m,k] * 𝐁[k,n]
71+
end
72+
𝐂[m,n] = 𝐂ₘₙ
73+
end
74+
end
75+
function gemmavx!(Cc::AbstractMatrix{Complex{T}}, Ac::AbstractMatrix{Complex{T}}, Bc::AbstractMatrix{Complex{T}}) where {T}
76+
A = reinterpret(reshape, T, Ac)
77+
B = reinterpret(reshape, T, Bc)
78+
C = reinterpret(reshape, T, Cc)
79+
@avx for m ∈ indices((A,C),2), n ∈ indices((B,C),3)
80+
Cre = zero(T)
81+
Cim = zero(T)
82+
for k ∈ indices((A,B),(3,2))
83+
Cre += A[1,m,k]*B[1,k,n] - A[2,m,k]*B[2,k,n]
84+
Cim += A[1,m,k]*B[2,k,n] + A[2,m,k]*B[1,k,n]
85+
end
86+
C[1,m,n] = Cre
87+
C[2,m,n] = Cim
88+
end
89+
end
90+
function gemmavxt!(𝐂, 𝐀, 𝐁)
91+
@avxt for m ∈ axes(𝐀,1), n ∈ axes(𝐁,2)
6892
𝐂ₘₙ = zero(eltype(𝐂))
6993
for k ∈ axes(𝐀,2)
7094
𝐂ₘₙ += 𝐀[m,k] * 𝐁[k,n]
7195
end
7296
𝐂[m,n] = 𝐂ₘₙ
7397
end
7498
end
99+
function gemmavxt!(Cc::AbstractMatrix{Complex{T}}, Ac::AbstractMatrix{Complex{T}}, Bc::AbstractMatrix{Complex{T}}) where {T}
100+
A = reinterpret(reshape, T, Ac)
101+
B = reinterpret(reshape, T, Bc)
102+
C = reinterpret(reshape, T, Cc)
103+
@avxt for m ∈ indices((A,C),2), n ∈ indices((B,C),3)
104+
Cre = zero(T)
105+
Cim = zero(T)
106+
for k ∈ indices((A,B),(3,2))
107+
Cre += A[1,m,k]*B[1,k,n] - A[2,m,k]*B[2,k,n]
108+
Cim += A[1,m,k]*B[2,k,n] + A[2,m,k]*B[1,k,n]
109+
end
110+
C[1,m,n] = Cre
111+
C[2,m,n] = Cim
112+
end
113+
end
75114
function jdot(a, b)
76115
s = zero(eltype(a))
77116
# @inbounds @simd ivdep for i ∈ eachindex(a,b)
@@ -88,6 +127,14 @@ function jdotavx(a, b)
88127
end
89128
s
90129
end
130+
function jdotavxt(a, b)
131+
s = zero(eltype(a))
132+
# @avx for i ∈ eachindex(a,b)
133+
@avxt for i ∈ eachindex(a)
134+
s += a[i] * b[i]
135+
end
136+
s
137+
end
91138
function jselfdot(a)
92139
s = zero(eltype(a))
93140
@inbounds @simd ivdep for i ∈ eachindex(a)
@@ -324,3 +371,18 @@ function filter2dunrolledavx!(out::AbstractMatrix, A::AbstractMatrix, kern::Size
324371
end
325372
out
326373
end
374+
375+
376+
# function smooth_line!(sl,nrm1,j,i1,rl,ih2,denom)
377+
# @fastmath @inbounds @simd ivdep for i=i1:2:nrm1
378+
# sl[i,j]=denom*(rl[i,j]+ih2*(sl[i,j-1]+sl[i-1,j]+sl[i+1,j]+sl[i,j+1]))
379+
# end
380+
# end
381+
# function smooth_line_avx!(sl,nrm1,j,i1,sl,rl,ih2,denom)
382+
# @avx for i=i1:2:nrm1
383+
# sl[i,j]=denom*(rl[i,j]+ih2*(sl[i,j-1]+sl[i-1,j]+sl[i+1,j]+sl[i,j+1]))
384+
# end
385+
# end
386+
387+
388+

src/codegen/lower_compute.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11

22

33
function load_constrained(op, u₁loop, u₂loop, innermost_loop_or_vloop, forprefetch = false)
4-
loopdeps = loopdependencies(op)
5-
dependsonu₁ = u₁loop ∈ loopdeps
6-
dependsonu₂ = u₂loop ∈ loopdeps
4+
dependsonu₁ = isu₁unrolled(op)
5+
dependsonu₂ = isu₂unrolled(op)
76
if forprefetch
87
(dependsonu₁ & dependsonu₂) || return false
98
end

src/codegen/lower_load.jl

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function lower_load_no_optranslation!(
117117
u = ifelse(opu₁, u₁, 1)
118118
mvar = Symbol(variable_name(op, Core.ifelse(isu₂unrolled(op), suffix,-1)), '_', u)
119119
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
120-
if all(op.ref.loopedindex)
120+
if all(op.ref.loopedindex) && !rejectcurly(op)
121121
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset)
122122
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
123123
add_memory_mask!(loadexpr, op, td, mask)
@@ -325,29 +325,62 @@ end
325325
function _lower_load!(
326326
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool} = indices_calculated_by_pointer_offsets(ls, op.ref)
327327
)
328-
omop = offsetloadcollection(ls)
329-
batchid, opind = omop.batchedcollectionmap[identifier(op)]
330-
# @show batchid == 0 (!isvectorized(op)) rejectinterleave(ls, op, td.vloop, idsformap)
331-
if batchid == 0 || (!isvectorized(op)) || (rejectinterleave(ls, op, td.vloop, omop.batchedcollections[batchid]))
328+
if rejectinterleave(op)
332329
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
333-
elseif opind == 1# only lower loads once
334-
# I do not believe it is possible for `opind == 1` to be lowered after an operation depending on a different opind.
335-
# lower_load_collection!(q, ls, op, td, mask, collectionid)
330+
else
336331
omop = offsetloadcollection(ls)
337-
collectionid, copind = omop.opidcollectionmap[identifier(op)]
338-
opidmap = offsetloadcollection(ls).opids[collectionid]
339-
idsformap = omop.batchedcollections[batchid]
340-
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
332+
batchid, opind = omop.batchedcollectionmap[identifier(op)]
333+
if opind == 1
334+
collectionid, copind = omop.opidcollectionmap[identifier(op)]
335+
opidmap = offsetloadcollection(ls).opids[collectionid]
336+
idsformap = omop.batchedcollections[batchid]
337+
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
338+
end
341339
end
342340
end
343-
function addive_loopinductvar_only(op::Operation)
341+
function additive_vectorized_loopinductvar_only(op::Operation)
342+
isvectorized(op) || return true
344343
isloopvalue(op) && return true
345344
iscompute(op) || return false
346345
additive_instr = (:add_fast, :(+), :vadd, :identity, :sub_fast, :(-), :vsub)
347346
Base.sym_in(instruction(op).instr, additive_instr) || return false
348-
return all(addive_loopinductvar_only, parents(op))
347+
return all(additive_vectorized_loopinductvar_only, parents(op))
348+
end
349+
# Checks if we cannot use `Unroll`
350+
function rejectcurly(ls::LoopSet, op::Operation, td::UnrollArgs)
351+
@unpack u₁loopsym, vloopsym = td
352+
rejectcurly(ls, op, u₁loopsym, vloopsym)
353+
end
354+
function rejectcurly(ls::LoopSet, op::Operation, u₁loopsym::Symbol, vloopsym::Symbol)
355+
indices = getindicesonly(op)
356+
li = op.ref.loopedindex
357+
AV = AU = false
358+
for (n,ind) ∈ enumerate(indices)
359+
# @show AU, op, n, ind, vloopsym, u₁loopsym
360+
if li[n]
361+
if ind === vloopsym
362+
AV && return true
363+
AV = true
364+
end
365+
if ind === u₁loopsym
366+
AU && return true
367+
AU = true
368+
end
369+
else
370+
opp = findop(parents(op), ind)
371+
# @show opp
372+
if isvectorized(opp)
373+
AV && return true
374+
AV = true
375+
end
376+
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp))
377+
AU && return true
378+
AU = true
379+
end
380+
end
381+
end
382+
false
349383
end
350-
351384
function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true})
352385
vloopsym = vloop.itersymbol; strd = step(vloop)
353386
isknown(strd) || return true
@@ -356,7 +389,7 @@ function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::Su
356389
li && continue
357390
for indop ∈ operations(ls)
358391
if (name(indop) === ind) && isvectorized(indop)
359-
addive_loopinductvar_only(op) || return true # so that it is `MM`
392+
additive_vectorized_loopinductvar_only(indop) || return true # so that it is `MM`
360393
end
361394
end
362395
end

src/codegen/lower_memory_common.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,34 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
172172
vloopsym = vloop.itersymbol
173173
indices = getindicesonly(op)
174174
vstep = step(vloop)
175-
# loopedindex = op.ref.loopedindex
175+
li = op.ref.loopedindex
176176
# @assert all(loopedindex)
177177
# @unpack u₁, u₁loopsym, vloopsym = td
178178
# @show vptr(op), inds_calc_by_ptr_offset
179179
# isone(u₁) && return mem_offset_u(op, td, inds_calc_by_ptr_offset, true)
180180
AV = AU = -1
181181
for (n,ind) ∈ enumerate(indices)
182-
if ind === vloopsym
183-
AV = n
184-
end
185-
if ind === u₁loopsym
186-
AU = n
182+
# @show AU, op, n, ind, vloopsym, u₁loopsym
183+
if li[n]
184+
if ind === vloopsym
185+
@assert AV == -1 # FIXME: these asserts should be replaced with checks that prevent using `unrolled_curly` in these cases (also to be reflected in cost modeling, to avoid those)
186+
AV = n
187+
end
188+
if ind === u₁loopsym
189+
@assert AU == -1
190+
AU = n
191+
end
192+
else
193+
opp = findop(parents(op), ind)
194+
# @show opp
195+
if isvectorized(opp)
196+
@assert AV == -1
197+
AV = n
198+
end
199+
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp))
200+
@assert AU == -1
201+
AU = n
202+
end
187203
end
188204
end
189205
# if AU == -1

src/codegen/lower_store.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ function reduce_expr_u₂(toreduct::Symbol, instr::Instruction, u₂::Int)
2626
end
2727
Expr(:call, lv(:reduce_tup), reduce_to_onevecunroll(instr), t)
2828
end
29-
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int, isu₁unrolled::Bool)
30-
if u₂ != -1
29+
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int, isu₁unrolled::Bool, isu₂unrolled::Bool)
30+
# if u₂ == -1
31+
# u₁u, u₂u = (true, false)
32+
# else
33+
# u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, _Umax)
34+
# end
35+
if isu₂unrolled# u₂ != -1
3136
_toreduct = Symbol(toreduct, 0)
3237
push!(q.args, Expr(:(=), _toreduct, reduce_expr_u₂(toreduct, instr, u₂)))
3338
else
@@ -105,7 +110,7 @@ function lower_store!(
105110

106111
omop = offsetloadcollection(ls)
107112
batchid, opind = omop.batchedcollectionmap[identifier(op)]
108-
if ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid]))
113+
if ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(op))
109114
(opind == 1) && lower_store_collection!(q, ls, op, ua, mask, inds_calc_by_ptr_offset)
110115
return
111116
end
@@ -173,10 +178,10 @@ end
173178

174179
function donot_tile_store(ls::LoopSet, op::Operation, vloop::Loop, reductfunc::Symbol, u₂::Int)
175180
(!((reductfunc === Symbol("")) && all(op.ref.loopedindex))) || (u₂ ≤ 1) || isconditionalmemop(op) && return true
176-
181+
rejectcurly(op) && return true
177182
omop = offsetloadcollection(ls)
178183
batchid, opind = omop.batchedcollectionmap[identifier(op)]
179-
return ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid]))
184+
return ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(op))
180185
end
181186

182187
# VectorizationBase implements optimizations for certain grouped stores

src/codegen/lowering.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ function lower_block(
7373
end
7474
else
7575
# for u ∈ 0:u₁-1 # u₁ && !u₂
76-
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true)
76+
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, false)
77+
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, false, true)
7778
# end
7879
end
7980
if n > 1 && prepost == 1
@@ -434,7 +435,7 @@ end
434435

435436
# TODO: handle tiled outer reductions; they will require a suffix arg
436437
function initialize_outer_reductions!(
437-
q::Expr, op::Operation, _Umax::Int, vectorized::Symbol, us::UnrollSpecification, rs::Expr
438+
q::Expr, ls::LoopSet, op::Operation, _Umax::Int, vectorized::Symbol, us::UnrollSpecification, rs::Expr
438439
)
439440
@unpack u₁, u₂ = us
440441
Umax = u₂ == -1 ? _Umax : u₁
@@ -459,12 +460,18 @@ function initialize_outer_reductions!(
459460
Expr(:call, reduct_zero, typeTr)
460461
end
461462
mvar = variable_name(op, -1)
463+
# u1u, u2u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, u₂loop, u₂max)
462464
if u₂ == -1
463465
push!(q.args, Expr(:(=), Symbol(mvar, '_', _Umax), z))
464466
else
465-
for u ∈ 0:_Umax-1
466-
# push!(q.args, Expr(:(=), Symbol(mvar, '_', u), z))
467-
push!(q.args, Expr(:(=), Symbol(mvar, u), z))
467+
u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, u₂)
468+
if u₁u
469+
push!(q.args, Expr(:(=), Symbol(mvar, '_', _Umax), z))
470+
else
471+
for u ∈ 0:_Umax-1
472+
# push!(q.args, Expr(:(=), Symbol(mvar, '_', u), z))
473+
push!(q.args, Expr(:(=), Symbol(mvar, u), z))
474+
end
468475
end
469476
end
470477
nothing
@@ -473,7 +480,7 @@ function initialize_outer_reductions!(q::Expr, ls::LoopSet, Umax::Int, vectorize
473480
rs = staticexpr(reg_size(ls))
474481
us = ls.unrollspecification[]
475482
for or ∈ ls.outer_reductions
476-
initialize_outer_reductions!(q, ls.operations[or], Umax, vectorized, us, rs)
483+
initialize_outer_reductions!(q, ls, ls.operations[or], Umax, vectorized, us, rs)
477484
end
478485
end
479486
initialize_outer_reductions!(ls::LoopSet, Umax::Int, vectorized::Symbol) = initialize_outer_reductions!(ls.preamble, ls, Umax, vectorized)
@@ -529,18 +536,21 @@ end
529536
## This performs reduction to one `Vec`
530537
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
531538
us = ls.unrollspecification[]
532-
u1f, u2f = if us.u₂ == -1 # TODO: these multiple meanings make code hard to follow. Simplify.
539+
u₁f, u₂f = if us.u₂ == -1 # TODO: these multiple meanings make code hard to follow. Simplify.
533540
ifelse(U == -1, us.u₁, U), -1
534541
else
535542
us.u₁, U
536543
end
537544
# u₁loop, u₂loop = getunrolled(ls)
545+
u₁loop = getloop(ls, us.u₁loopnum).itersymbol
546+
u₂loop = getloop(ls, us.u₂loopnum).itersymbol
538547
for or ∈ ls.outer_reductions
539548
op = ls.operations[or]
540549
var = name(op)
541550
mvar = mangledvar(op)
542551
instr = instruction(op)
543-
reduce_expr!(q, mvar, instr, u1f, u2f, isu₁unrolled(op))
552+
u₁u, u₂u = isunrolled_sym(op, u₁loop, u₂loop, u₂f)
553+
reduce_expr!(q, mvar, instr, u₁f, u₂f, u₁u, u₂u)#isu₁unrolled(op))
544554
if !iszero(length(ls.opdict))
545555
if (isu₁unrolled(op) | isu₂unrolled(op))
546556
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Symbol(mvar, "##onevec##"), var)))

0 commit comments

Comments
 (0)