Skip to content

Commit dd3afad

Browse files
authored
Merge pull request #180 from timholy/teh/inference
Some inferrability and precompile improvements
2 parents ad11641 + 57b715f commit dd3afad

File tree

7 files changed

+48
-43
lines changed

7 files changed

+48
-43
lines changed

src/add_compute.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function add_parent!(
5757
opp = getop(ls, var, elementbytes)
5858
# if var === :kern_1_1
5959
# @show operations(ls) ls.preamble_symsym
60-
# end
60+
# end
6161
# @show var opp first(operations(ls)) opp === first(operations(ls))
6262
if iscompute(opp) && instruction(opp).instr === :identity && length(loopdependencies(opp)) < position && isone(length(parents(opp))) && name(opp) === name(first(parents(opp)))
6363
first(parents(opp))
@@ -168,7 +168,7 @@ function add_reduction_update_parent!(
168168
# if !isouterreduction && !isreductzero(parent, ls, reduct_zero)
169169
add_reduct_instruct = !isouterreduction && !isconstant(parent)
170170
if add_reduct_instruct
171-
# We add
171+
# We add
172172
reductcombine = reduction_scalar_combine(instrclass)
173173
# reductcombine = :identity
174174
reductsym = gensym(:reduction)
@@ -222,7 +222,10 @@ function add_compute!(
222222
# instr = instruction(first(ex.args))::Symbol
223223
instr = instruction!(ls, first(ex.args))::Instruction
224224
args = @view(ex.args[2:end])
225-
(instr.instr === :(^) && length(args) == 2 && (args[2] isa Number)) && return add_pow!(ls, var, args[1], args[2], elementbytes, position)
225+
if instr.instr === :(^) && length(args) == 2
226+
arg2 = args[2]
227+
arg2 isa Number && return add_pow!(ls, var, args[1], arg2, elementbytes, position)
228+
end
226229
vparents = Operation[]
227230
deps = Symbol[]
228231
reduceddeps = Symbol[]
@@ -232,7 +235,7 @@ function add_compute!(
232235
if var === arg
233236
reduction_ind = ind
234237
# add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
235-
getop(ls, arg, elementbytes)
238+
getop(ls, arg::Symbol, elementbytes) # weird that this needs annotation
236239
elseif arg isa Expr
237240
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
238241
if isref
@@ -303,7 +306,7 @@ end
303306

304307
# adds x ^ (p::Real)
305308
function add_pow!(
306-
ls::LoopSet, var::Symbol, x, p::Real, elementbytes::Int, position::Int
309+
ls::LoopSet, var::Symbol, @nospecialize(x), p::Real, elementbytes::Int, position::Int
307310
)
308311
xop::Operation = if x isa Expr
309312
add_operation!(ls, gensym(:xpow), x, elementbytes, position)

src/add_loads.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ function maybeaddref!(ls::LoopSet, op)
22
ref = op.ref
33
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
44
# try to CSE
5-
if isnothing(id)
5+
if id === nothing
66
push!(ls.syms_aliasing_refs, name(op))
77
push!(ls.refs_aliasing_syms, ref)
88
0
99
else
1010
id
11-
end
11+
end
1212
end
1313

1414
function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true, broadcast::Bool = false)

src/constructors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function LoopSet(q::Expr, mod::Symbol = :Main)
6363
resize!(ls.loop_order, num_loops(ls))
6464
ls
6565
end
66-
LoopSet(q::Expr, m::Module) = LoopSet(macroexpand(m, q), Symbol(m))
66+
LoopSet(q::Expr, m::Module) = LoopSet(macroexpand(m, q)::Expr, Symbol(m))
6767

6868
"""
6969
@avx
@@ -143,6 +143,7 @@ use their `parent`. Triangular loops aren't yet supported.
143143
"""
144144
macro avx(q)
145145
q = macroexpand(__module__, q)
146+
isa(q, Expr) || return q
146147
q2 = if q.head === :for
147148
setup_call(LoopSet(q, __module__), q)
148149
else# assume broadcast

src/graphs.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function looprange(loop::Loop, incr::Int, mangledname)
136136
end
137137
function terminatecondition(
138138
loop::Loop, us::UnrollSpecification, n::Int, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
139-
)
139+
)
140140
if !isvectorized(us, n)
141141
looprange(loop, UF, mangledname)
142142
elseif inclmask
@@ -345,7 +345,7 @@ function LoopSet(mod::Symbol)
345345
Tuple{Int,Symbol}[],
346346
Tuple{Int,Int}[],
347347
Tuple{Int,Float64}[],
348-
Int[],Int[],
348+
Tuple{Int,NumberType}[],Tuple{Int,Symbol}[],
349349
Symbol[], Symbol[], Symbol[],
350350
ArrayReferenceMeta[],
351351
Matrix{Float64}(undef, 4, 2),
@@ -363,7 +363,7 @@ function cacheunrolled!(ls::LoopSet, u₁loop, u₂loop, vectorized)
363363
for opp parents(op)
364364
push!(children(opp), op)
365365
end
366-
end
366+
end
367367
end
368368

369369
num_loops(ls::LoopSet) = length(ls.loops)
@@ -404,7 +404,7 @@ function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int)
404404
end
405405
getop(ls::LoopSet, i::Int) = ls.operations[i]
406406

407-
# """
407+
# """
408408
# Returns an operation with the same name as `s`.
409409
# """
410410
# function getoperation(ls::LoopSet, s::Symbol)
@@ -486,15 +486,16 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
486486
itersym = (looprange.args[1])::Symbol
487487
r = looprange.args[2]
488488
if isexpr(r, :call)
489+
r = r::Expr # julia#37342
489490
f = first(r.args)
490491
loop::Loop = if f === :(:)
491492
lower = r.args[2]
492493
upper = r.args[3]
493494
lii::Bool = lower isa Integer
494-
liiv::Int = lii ? convert(Int, lower) : 1
495+
liiv::Int = lii ? convert(Int, lower::Integer) : 1
495496
uii::Bool = upper isa Integer
496497
if lii & uii # both are integers
497-
Loop(itersym, liiv, convert(Int, upper))
498+
Loop(itersym, liiv, convert(Int, upper::Integer)::Int)
498499
elseif lii # only lower bound is an integer
499500
if upper isa Symbol
500501
Loop(itersym, liiv, upper)
@@ -504,7 +505,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
504505
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
505506
end
506507
elseif uii # only upper bound is an integer
507-
uiiv = convert(Int, upper)
508+
uiiv = convert(Int, upper::Integer)::Int
508509
Loop(itersym, add_loop_bound!(ls, itersym, lower, false), uiiv)
509510
else # neither are integers
510511
L = add_loop_bound!(ls, itersym, lower, false)
@@ -544,16 +545,16 @@ end
544545
function register_loop!(ls::LoopSet, looprange::Expr)
545546
if looprange.head === :block # multiple loops
546547
for lr looprange.args
547-
register_single_loop!(ls, lr)
548+
register_single_loop!(ls, lr::Expr)
548549
end
549550
else
550551
@assert looprange.head === :(=)
551552
register_single_loop!(ls, looprange)
552553
end
553554
end
554555
function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int)
555-
register_loop!(ls, q.args[1])
556-
body = q.args[2]
556+
register_loop!(ls, q.args[1]::Expr)
557+
body = q.args[2]::Expr
557558
position = length(ls.loopsymbols)
558559
if body.head === :block
559560
add_block!(ls, body, elementbytes, position)
@@ -685,10 +686,8 @@ function add_operation!(
685686
end
686687
end
687688

688-
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Symbol, array, rawindices, elementbytes::Int, position::Int)
689-
add_store!(ls, RHS, array, rawindices, elementbytes)
690-
end
691-
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Expr, array, rawindices, elementbytes::Int, position::Int)
689+
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Union{Symbol,Expr}, array, rawindices, elementbytes::Int, position::Int)
690+
RHS isa Symbol && return add_store!(ls, RHS, array, rawindices, elementbytes)
692691
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
693692
cachedparents = copy(mpref.parents)
694693
ref = mpref.mref.ref
@@ -706,9 +705,9 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
706705
finex = first(ex.args)::Symbol
707706
if finex === :setindex!
708707
array, rawindices = ref_from_setindex!(ls, ex)
709-
prepare_rhs_for_storage!(ls, ex.args[3], array, rawindices, elementbytes, position)
708+
prepare_rhs_for_storage!(ls, ex.args[3]::Union{Symbol,Expr}, array, rawindices, elementbytes, position)
710709
else
711-
throw("Function $finex not recognized.")
710+
error("Function $finex not recognized.")
712711
end
713712
elseif ex.head === :(=)
714713
LHS = ex.args[1]
@@ -729,7 +728,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
729728
array, rawindices = ref_from_expr!(ls, LHS)
730729
prepare_rhs_for_storage!(ls, RHS, array, rawindices, elementbytes, position)
731730
else
732-
add_store_ref!(ls, RHS, LHS, elementbytes)
731+
add_store_ref!(ls, RHS, LHS, elementbytes) # is this necessary? (Extension API?)
733732
end
734733
elseif LHS.head === :tuple
735734
@assert length(LHS.args) 9 "Functions returning more than 9 values aren't currently supported."

src/precompile.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ function _precompile_()
1111
Base.precompile(Tuple{typeof(LoopVectorization.add_compute!),LoopVectorization.LoopSet,Symbol,Expr,Int64,Int64,Nothing})
1212
Base.precompile(Tuple{typeof(LoopVectorization.add_constant!),LoopVectorization.LoopSet,Float64,Vector{Symbol},Symbol,Int64})
1313
Base.precompile(Tuple{typeof(LoopVectorization.add_if!),LoopVectorization.LoopSet,Symbol,Expr,Int64,Int64,LoopVectorization.ArrayReferenceMetaPosition})
14+
Base.precompile(Tuple{typeof(avx_body), LoopSet, Tuple{Int8,Int8,Int8,Int}})
15+
Base.precompile(Tuple{typeof(avx_loopset), Vector{Instruction}, Vector{OperationStruct}, Vector{ArrayRefStruct}, Core.SimpleVector, Core.SimpleVector, Core.SimpleVector, Any})
1416
Base.precompile(Tuple{typeof(LoopVectorization.add_mref!),LoopVectorization.LoopSet,LoopVectorization.ArrayReferenceMeta,Int64,Type{VectorizationBase.StridedBitPointer{2, 1, 0, (1, 2), Tuple{ArrayInterface.StaticInt{1}, Int64}, Tuple{ArrayInterface.StaticInt{1}, ArrayInterface.StaticInt{1}}}},Symbol})
1517
Base.precompile(Tuple{typeof(LoopVectorization.add_mref!),LoopVectorization.LoopSet,LoopVectorization.ArrayReferenceMeta,Int64,Type{VectorizationBase.StridedPointer{Float32, 2, 1, 0, (1, 2), Tuple{ArrayInterface.StaticInt{4}, ArrayInterface.StaticInt{292}}, Tuple{ArrayInterface.StaticInt{1}, ArrayInterface.StaticInt{1}}}},Symbol})
1618
Base.precompile(Tuple{typeof(LoopVectorization.add_mref!),LoopVectorization.LoopSet,LoopVectorization.ArrayReferenceMeta,Int64,Type{VectorizationBase.StridedPointer{Float32, 2, 1, 0, (1, 2), Tuple{ArrayInterface.StaticInt{4}, ArrayInterface.StaticInt{308}}, Tuple{ArrayInterface.StaticInt{1}, ArrayInterface.StaticInt{1}}}},Symbol})

src/reconstruct_loopset.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ function add_mref!(
138138
return pushvarg!(ls, ar, i, name)
139139
end
140140
lic = copy(li);
141-
inds = getindices(ar); indsc = copy(inds);
141+
inds = getindices(ar); indsc = copy(inds);
142142
offsets = ar.ref.offsets; offsetsc = copy(offsets);
143143

144144
# must now sort array's inds, and stack pointer's
145145
tmpsp = gensym(name)
146146
pushvarg!(ls, ar, i, tmpsp)
147-
# pushpreamble!(ls,
147+
# pushpreamble!(ls,
148148
strd_tup = Expr(:tuple)
149149
offsets_tup = Expr(:tuple)
150150
for (i, p) enumerate(sp)
@@ -417,7 +417,8 @@ function sizeofeltypes(v, num_arrays)::Int
417417
# sizeof(T)
418418
end
419419

420-
function avx_loopset(instr, ops, arf, AM, LPSYM, LB, @nospecialize(vargs))
420+
function avx_loopset(instr::Vector{Instruction}, ops::Vector{OperationStruct}, arf::Vector{ArrayRefStruct},
421+
AM::Core.SimpleVector, LPSYM::Core.SimpleVector, LB::Core.SimpleVector, @nospecialize(vargs))
421422
ls = LoopSet(:LoopVectorization)
422423
num_arrays = length(arf)
423424
elementbytes = sizeofeltypes(vargs, num_arrays)
@@ -436,7 +437,7 @@ function avx_loopset(instr, ops, arf, AM, LPSYM, LB, @nospecialize(vargs))
436437
num_params = extract_external_functions!(ls, num_params)
437438
ls
438439
end
439-
function avx_body(ls, UNROLL)
440+
function avx_body(ls::LoopSet, UNROLL::Tuple{Int8,Int8,Int8,Int})
440441
inline, u₁, u₂, W = UNROLL
441442
ls.vector_width[] = W
442443
q = iszero(u₁) ? lower_and_split_loops(ls, inline % Int) : lower(ls, u₁ % Int, u₂ % Int, inline % Int)
@@ -448,7 +449,7 @@ function _avx_loopset_debug(::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM},
448449
@show OPS ARF AM LPSYM LB vargs
449450
_avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, typeof.(vargs))
450451
end
451-
function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, @nospecialize(vargs))
452+
function _avx_loopset(OPSsv::Core.SimpleVector, ARFsv::Core.SimpleVector, AMsv::Core.SimpleVector, LPSYMsv::Core.SimpleVector, LBsv::Core.SimpleVector, @nospecialize(vargs))
452453
nops = length(OPSsv) ÷ 3
453454
instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i 0:nops-1]
454455
ops = OperationStruct[ OPSsv[3i] for i 1:nops ]

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ end
4848
function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool = false)
4949
length(argv) < 3 && return length(call.args) == 4, cnmul, csub
5050
fun = first(argv)
51-
isadd = fun === :+ || fun === :vadd! || fun === :vadd || fun == :(Base.FastMath.add_fast)
52-
issub = fun === :- || fun === :vsub! || fun === :vsub || fun == :(Base.FastMath.sub_fast)
51+
isadd = fun === :+ || fun === :vadd! || fun === :vadd || (fun == :(Base.FastMath.add_fast))::Bool
52+
issub = fun === :- || fun === :vsub! || fun === :vsub || (fun == :(Base.FastMath.sub_fast))::Bool
5353
if !(isadd | issub)
5454
return length(call.args) == 4, cnmul, csub
5555
end
@@ -60,7 +60,7 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
6060
exa = ex.args
6161
f = first(exa)
6262
exav = @view(exa[2:end])
63-
if f === :* || f === :vmul! || f === :vmul || f == :(Base.FastMath.mul_fast)
63+
if f === :* || f === :vmul! || f === :vmul || (f == :(Base.FastMath.mul_fast))::Bool
6464
a, b = mulexpr(exav)
6565
call.args[2] = a
6666
call.args[3] = b
@@ -107,28 +107,27 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
107107
csub = false
108108
end
109109
return true, cnmul, csub
110-
end
110+
end
111111
end
112112
end
113113
end
114114
length(call.args) == 4, cnmul, csub
115115
end
116-
117-
function capture_muladd(ex::Expr, mod, LHS = nothing)
116+
117+
function capture_muladd(ex::Expr, mod)
118118
call = Expr(:call, Symbol(""), Symbol(""), Symbol(""))
119119
found, nmul, sub = recursive_muladd_search!(call, ex.args)
120120
found || return ex
121121
# a, b, c = call.args[2], call.args[3], call.args[4]
122122
# call.args[2], call.args[3], call.args[4] = c, a, b
123-
clobber = false#call.args[4] == LHS
124123
f = if nmul && sub
125-
clobber ? :vfnmsub231 : :vfnmsub
124+
:vfnmsub
126125
elseif nmul
127-
clobber ? :vfnmadd231 : :vfnmadd
126+
:vfnmadd
128127
elseif sub
129-
clobber ? :vfmsub231 : :vfmsub
128+
:vfmsub
130129
else
131-
clobber ? :vfmadd231 : :vfmadd
130+
:vfmadd
132131
end
133132
if mod === nothing
134133
call.args[1] = f
@@ -139,7 +138,7 @@ function capture_muladd(ex::Expr, mod, LHS = nothing)
139138
end
140139

141140
contract_pass!(::Any, ::Any) = nothing
142-
function contract!(expr::Expr, ex::Expr, i::Int, mod = nothing)
141+
function contract!(expr::Expr, ex::Expr, i::Int, mod)
143142
# if ex.head === :call
144143
# expr.args[i] = capture_muladd(ex, mod)
145144
if ex.head === :(+=)
@@ -163,7 +162,7 @@ function contract!(expr::Expr, ex::Expr, i::Int, mod = nothing)
163162
RHS = ex.args[2]
164163
# @show ex
165164
if RHS isa Expr && RHS.head === :call
166-
ex.args[2] = capture_muladd(RHS, mod, ex.args[1])
165+
ex.args[2] = capture_muladd(RHS, mod)
167166
end
168167
end
169168
contract_pass!(expr.args[i], mod)

0 commit comments

Comments
 (0)