Skip to content

Commit 5ecb9f7

Browse files
committed
Small improvements in inferrability
1 parent 5cbab62 commit 5ecb9f7

File tree

6 files changed

+41
-37
lines changed

6 files changed

+41
-37
lines changed

src/add_compute.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function add_parent!(
5757
opp = getop(ls, var, elementbytes)
5858
# if var === :kern_1_1
5959
# @show operations(ls) ls.preamble_symsym
60-
# end
60+
# end
6161
# @show var opp first(operations(ls)) opp === first(operations(ls))
6262
if iscompute(opp) && instruction(opp).instr === :identity && length(loopdependencies(opp)) < position && isone(length(parents(opp))) && name(opp) === name(first(parents(opp)))
6363
first(parents(opp))
@@ -168,7 +168,7 @@ function add_reduction_update_parent!(
168168
# if !isouterreduction && !isreductzero(parent, ls, reduct_zero)
169169
add_reduct_instruct = !isouterreduction && !isconstant(parent)
170170
if add_reduct_instruct
171-
# We add
171+
# We add
172172
reductcombine = reduction_scalar_combine(instrclass)
173173
# reductcombine = :identity
174174
reductsym = gensym(:reduction)
@@ -222,7 +222,10 @@ function add_compute!(
222222
# instr = instruction(first(ex.args))::Symbol
223223
instr = instruction!(ls, first(ex.args))::Instruction
224224
args = @view(ex.args[2:end])
225-
(instr.instr === :(^) && length(args) == 2 && (args[2] isa Number)) && return add_pow!(ls, var, args[1], args[2], elementbytes, position)
225+
if instr.instr === :(^) && length(args) == 2
226+
arg2 = args[2]
227+
arg2 isa Number && return add_pow!(ls, var, args[1], arg2, elementbytes, position)
228+
end
226229
vparents = Operation[]
227230
deps = Symbol[]
228231
reduceddeps = Symbol[]
@@ -232,7 +235,7 @@ function add_compute!(
232235
if var === arg
233236
reduction_ind = ind
234237
# add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
235-
getop(ls, arg, elementbytes)
238+
getop(ls, arg::Symbol, elementbytes) # weird that this needs annotation
236239
elseif arg isa Expr
237240
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
238241
if isref
@@ -303,7 +306,7 @@ end
303306

304307
# adds x ^ (p::Real)
305308
function add_pow!(
306-
ls::LoopSet, var::Symbol, x, p::Real, elementbytes::Int, position::Int
309+
ls::LoopSet, var::Symbol, @nospecialize(x), p::Real, elementbytes::Int, position::Int
307310
)
308311
xop::Operation = if x isa Expr
309312
add_operation!(ls, gensym(:xpow), x, elementbytes, position)

src/add_loads.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ function maybeaddref!(ls::LoopSet, op)
22
ref = op.ref
33
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
44
# try to CSE
5-
if isnothing(id)
5+
if id === nothing
66
push!(ls.syms_aliasing_refs, name(op))
77
push!(ls.refs_aliasing_syms, ref)
88
0
99
else
1010
id
11-
end
11+
end
1212
end
1313

1414
function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true, broadcast::Bool = false)

src/constructors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function LoopSet(q::Expr, mod::Symbol = :Main)
6363
resize!(ls.loop_order, num_loops(ls))
6464
ls
6565
end
66-
LoopSet(q::Expr, m::Module) = LoopSet(macroexpand(m, q), Symbol(m))
66+
LoopSet(q::Expr, m::Module) = LoopSet(macroexpand(m, q)::Expr, Symbol(m))
6767

6868
"""
6969
@avx
@@ -143,6 +143,7 @@ use their `parent`. Triangular loops aren't yet supported.
143143
"""
144144
macro avx(q)
145145
q = macroexpand(__module__, q)
146+
isa(q, Expr) || return q
146147
q2 = if q.head === :for
147148
setup_call(LoopSet(q, __module__), q)
148149
else# assume broadcast

src/graphs.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function looprange(loop::Loop, incr::Int, mangledname)
126126
end
127127
function terminatecondition(
128128
loop::Loop, us::UnrollSpecification, n::Int, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
129-
)
129+
)
130130
if !isvectorized(us, n)
131131
looprange(loop, UF, mangledname)
132132
elseif inclmask
@@ -335,7 +335,7 @@ function LoopSet(mod::Symbol)
335335
Tuple{Int,Symbol}[],
336336
Tuple{Int,Int}[],
337337
Tuple{Int,Float64}[],
338-
Int[],Int[],
338+
Tuple{Int,NumberType}[],Tuple{Int,Symbol}[],
339339
Symbol[], Symbol[], Symbol[],
340340
ArrayReferenceMeta[],
341341
Matrix{Float64}(undef, 4, 2),
@@ -353,7 +353,7 @@ function cacheunrolled!(ls::LoopSet, u₁loop, u₂loop, vectorized)
353353
for opp parents(op)
354354
push!(children(opp), op)
355355
end
356-
end
356+
end
357357
end
358358

359359
num_loops(ls::LoopSet) = length(ls.loops)
@@ -394,7 +394,7 @@ function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int)
394394
end
395395
getop(ls::LoopSet, i::Int) = ls.operations[i]
396396

397-
# """
397+
# """
398398
# Returns an operation with the same name as `s`.
399399
# """
400400
# function getoperation(ls::LoopSet, s::Symbol)
@@ -476,15 +476,16 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
476476
itersym = (looprange.args[1])::Symbol
477477
r = looprange.args[2]
478478
if isexpr(r, :call)
479+
r = r::Expr # julia#37342
479480
f = first(r.args)
480481
loop::Loop = if f === :(:)
481482
lower = r.args[2]
482483
upper = r.args[3]
483484
lii::Bool = lower isa Integer
484-
liiv::Int = lii ? convert(Int, lower) : 1
485+
liiv::Int = lii ? convert(Int, lower::Integer) : 1
485486
uii::Bool = upper isa Integer
486487
if lii & uii # both are integers
487-
Loop(itersym, liiv, convert(Int, upper))
488+
Loop(itersym, liiv, convert(Int, upper::Integer)::Int)
488489
elseif lii # only lower bound is an integer
489490
if upper isa Symbol
490491
Loop(itersym, liiv, upper)
@@ -494,7 +495,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
494495
Loop(itersym, liiv, add_loop_bound!(ls, itersym, upper, true))
495496
end
496497
elseif uii # only upper bound is an integer
497-
uiiv = convert(Int, upper)
498+
uiiv = convert(Int, upper::Integer)::Int
498499
Loop(itersym, add_loop_bound!(ls, itersym, lower, false), uiiv)
499500
else # neither are integers
500501
L = add_loop_bound!(ls, itersym, lower, false)
@@ -534,16 +535,16 @@ end
534535
function register_loop!(ls::LoopSet, looprange::Expr)
535536
if looprange.head === :block # multiple loops
536537
for lr looprange.args
537-
register_single_loop!(ls, lr)
538+
register_single_loop!(ls, lr::Expr)
538539
end
539540
else
540541
@assert looprange.head === :(=)
541542
register_single_loop!(ls, looprange)
542543
end
543544
end
544545
function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int)
545-
register_loop!(ls, q.args[1])
546-
body = q.args[2]
546+
register_loop!(ls, q.args[1]::Expr)
547+
body = q.args[2]::Expr
547548
position = length(ls.loopsymbols)
548549
if body.head === :block
549550
add_block!(ls, body, elementbytes, position)
@@ -675,10 +676,8 @@ function add_operation!(
675676
end
676677
end
677678

678-
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Symbol, array, rawindices, elementbytes::Int, position::Int)
679-
add_store!(ls, RHS, array, rawindices, elementbytes)
680-
end
681-
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Expr, array, rawindices, elementbytes::Int, position::Int)
679+
function prepare_rhs_for_storage!(ls::LoopSet, RHS::Union{Symbol,Expr}, array, rawindices, elementbytes::Int, position::Int)
680+
RHS isa Symbol && return add_store!(ls, RHS, array, rawindices, elementbytes)
682681
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
683682
cachedparents = copy(mpref.parents)
684683
ref = mpref.mref.ref
@@ -696,9 +695,9 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
696695
finex = first(ex.args)::Symbol
697696
if finex === :setindex!
698697
array, rawindices = ref_from_setindex!(ls, ex)
699-
prepare_rhs_for_storage!(ls, ex.args[3], array, rawindices, elementbytes, position)
698+
prepare_rhs_for_storage!(ls, ex.args[3]::Union{Symbol,Expr}, array, rawindices, elementbytes, position)
700699
else
701-
throw("Function $finex not recognized.")
700+
error("Function $finex not recognized.")
702701
end
703702
elseif ex.head === :(=)
704703
LHS = ex.args[1]
@@ -719,7 +718,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
719718
array, rawindices = ref_from_expr!(ls, LHS)
720719
prepare_rhs_for_storage!(ls, RHS, array, rawindices, elementbytes, position)
721720
else
722-
add_store_ref!(ls, RHS, LHS, elementbytes)
721+
add_store_ref!(ls, RHS, LHS, elementbytes) # is this necessary? (Extension API?)
723722
end
724723
elseif LHS.head === :tuple
725724
@assert length(LHS.args) 9 "Functions returning more than 9 values aren't currently supported."

src/reconstruct_loopset.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ function add_mref!(
138138
return pushvarg!(ls, ar, i, name)
139139
end
140140
lic = copy(li);
141-
inds = getindices(ar); indsc = copy(inds);
141+
inds = getindices(ar); indsc = copy(inds);
142142
offsets = ar.ref.offsets; offsetsc = copy(offsets);
143143

144144
# must now sort array's inds, and stack pointer's
145145
tmpsp = gensym(name)
146146
pushvarg!(ls, ar, i, tmpsp)
147-
# pushpreamble!(ls,
147+
# pushpreamble!(ls,
148148
strd_tup = Expr(:tuple)
149149
offsets_tup = Expr(:tuple)
150150
for (i, p) enumerate(sp)
@@ -417,7 +417,8 @@ function sizeofeltypes(v, num_arrays)::Int
417417
# sizeof(T)
418418
end
419419

420-
function avx_loopset(instr, ops, arf, AM, LPSYM, LB, @nospecialize(vargs))
420+
function avx_loopset(instr::Vector{Instruction}, ops::Vector{OperationStruct}, arf::Vector{ArrayRefStruct},
421+
AM::Core.SimpleVector, LPSYM::Core.SimpleVector, LB::Core.SimpleVector, @nospecialize(vargs))
421422
ls = LoopSet(:LoopVectorization)
422423
num_arrays = length(arf)
423424
elementbytes = sizeofeltypes(vargs, num_arrays)
@@ -436,7 +437,7 @@ function avx_loopset(instr, ops, arf, AM, LPSYM, LB, @nospecialize(vargs))
436437
num_params = extract_external_functions!(ls, num_params)
437438
ls
438439
end
439-
function avx_body(ls, UNROLL)
440+
function avx_body(ls::LoopSet, UNROLL::Tuple{Int8,Int8,Int8,Int})
440441
inline, u₁, u₂, W = UNROLL
441442
ls.vector_width[] = W
442443
q = iszero(u₁) ? lower_and_split_loops(ls, inline % Int) : lower(ls, u₁ % Int, u₂ % Int, inline % Int)
@@ -448,7 +449,7 @@ function _avx_loopset_debug(::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM},
448449
@show OPS ARF AM LPSYM LB vargs
449450
_avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, typeof.(vargs))
450451
end
451-
function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, @nospecialize(vargs))
452+
function _avx_loopset(OPSsv::Core.SimpleVector, ARFsv::Core.SimpleVector, AMsv::Core.SimpleVector, LPSYMsv::Core.SimpleVector, LBsv::Core.SimpleVector, @nospecialize(vargs))
452453
nops = length(OPSsv) ÷ 3
453454
instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i 0:nops-1]
454455
ops = OperationStruct[ OPSsv[3i] for i 1:nops ]

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ end
4848
function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool = false)
4949
length(argv) < 3 && return length(call.args) == 4, cnmul, csub
5050
fun = first(argv)
51-
isadd = fun === :+ || fun === :vadd! || fun === :vadd || fun == :(Base.FastMath.add_fast)
52-
issub = fun === :- || fun === :vsub! || fun === :vsub || fun == :(Base.FastMath.sub_fast)
51+
isadd = fun === :+ || fun === :vadd! || fun === :vadd || (fun == :(Base.FastMath.add_fast))::Bool
52+
issub = fun === :- || fun === :vsub! || fun === :vsub || (fun == :(Base.FastMath.sub_fast))::Bool
5353
if !(isadd | issub)
5454
return length(call.args) == 4, cnmul, csub
5555
end
@@ -60,7 +60,7 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
6060
exa = ex.args
6161
f = first(exa)
6262
exav = @view(exa[2:end])
63-
if f === :* || f === :vmul! || f === :vmul || f == :(Base.FastMath.mul_fast)
63+
if f === :* || f === :vmul! || f === :vmul || (f == :(Base.FastMath.mul_fast))::Bool
6464
a, b = mulexpr(exav)
6565
call.args[2] = a
6666
call.args[3] = b
@@ -107,14 +107,14 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
107107
csub = false
108108
end
109109
return true, cnmul, csub
110-
end
110+
end
111111
end
112112
end
113113
end
114114
length(call.args) == 4, cnmul, csub
115115
end
116-
117-
function capture_muladd(ex::Expr, mod, LHS = nothing)
116+
117+
function capture_muladd(ex::Expr, mod, @nospecialize(LHS) = nothing)
118118
call = Expr(:call, Symbol(""), Symbol(""), Symbol(""))
119119
found, nmul, sub = recursive_muladd_search!(call, ex.args)
120120
found || return ex
@@ -139,7 +139,7 @@ function capture_muladd(ex::Expr, mod, LHS = nothing)
139139
end
140140

141141
contract_pass!(::Any, ::Any) = nothing
142-
function contract!(expr::Expr, ex::Expr, i::Int, mod = nothing)
142+
function contract!(expr::Expr, ex::Expr, i::Int, mod)
143143
# if ex.head === :call
144144
# expr.args[i] = capture_muladd(ex, mod)
145145
if ex.head === :(+=)

0 commit comments

Comments
 (0)