Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
debug
Manifest.toml
103 changes: 52 additions & 51 deletions src/infer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,70 +182,71 @@ function openbranches(bl)
end

function step!(inf::Inference)
frame, b, f, ip = pop!(inf.queue)
block, stmts = getblock(frame, b, f)
uninline!(frame, b, f)
if ip <= length(stmts)
var = stmts[ip]
st = block[var]
if isexpr(st.expr, :call)
T = infercall!(inf, (frame, b, f, ip), block, st.expr)
if T != Union{}
block.ir[var] = stmt(block[var], type = _union(st.type, T))
push!(inf.queue, (frame, b, f, ip+1))
end
elseif isexpr(st.expr, :inbounds)
push!(inf.queue, (frame, b, f, ip+1))
frame, b, f, ip = pop!(inf.queue)
block, stmts = getblock(frame, b, f)
uninline!(frame, b, f)
if ip <= length(stmts)
var = stmts[ip]
st = block[var]
st.expr isa QuoteNode && return
if isexpr(st.expr, :call)
T = infercall!(inf, (frame, b, f, ip), block, st.expr)
if T != Union{}
block.ir[var] = stmt(block[var], type = _union(st.type, T))
push!(inf.queue, (frame, b, f, ip+1))
end
elseif isexpr(st.expr, :inbounds)
push!(inf.queue, (frame, b, f, ip+1))
else
error("Unrecognised expression $(st.expr)")
end
elseif (brs = openbranches(block); length(brs) == 1 && !isreturn(brs[1])
&& !(brs[1].block == length(frame.ir.blocks)))
inferbranch!(inf, frame, b, f, brs[1])
else
error("Unrecognised expression $(st.expr)")
end
elseif (brs = openbranches(block); length(brs) == 1 && !isreturn(brs[1])
&& !(brs[1].block == length(frame.ir.blocks)))
inferbranch!(inf, frame, b, f, brs[1])
else
for br in brs
if isreturn(br)
T = exprtype(block.ir, IRTools.returnvalue(block))
_issubtype(T, frame.rettype) && return
frame.rettype = _union(frame.rettype, T)
foreach(loc -> push!(inf.queue, loc), frame.edges)
else
args = exprtype.((block.ir,), arguments(br))
if blockargs!(IRTools.block(frame.ir, br.block), args)
push!(inf.queue, (frame, br.block, 0, 1))
for br in brs
if isreturn(br)
T = exprtype(block.ir, IRTools.returnvalue(block))
_issubtype(T, frame.rettype) && return
frame.rettype = _union(frame.rettype, T)
foreach(loc -> push!(inf.queue, loc), frame.edges)
else
args = exprtype.((block.ir,), arguments(br))
if blockargs!(IRTools.block(frame.ir, br.block), args)
push!(inf.queue, (frame, br.block, 0, 1))
end
end
end
end
end
end
return
return
end

function infer!(inf::Inference)
while !isempty(inf.queue)
step!(inf)
end
for (_, fr) in inf.frames
inline_bbs!(fr)
fr.ir |> IRTools.Inner.trimblocks!
end
return inf
while !isempty(inf.queue)
step!(inf)
end
for (_, fr) in inf.frames
inline_bbs!(fr)
fr.ir |> IRTools.Inner.trimblocks!
end
return inf
end

function Inference(fr::Frame, P)
q = WorkQueue{Any}()
push!(q, (fr, 1, 0, 1))
Inference(Dict(argtypes(fr.ir)=>fr), q, IdDict(), P)
q = WorkQueue{Any}()
push!(q, (fr, 1, 0, 1))
Inference(Dict(argtypes(fr.ir)=>fr), q, IdDict(), P)
end

function infer!(P, ir::IR, args...)
fr = frame(ir, args...)
inf = Inference(fr, P)
infer!(inf)
fr = frame(ir, args...)
inf = Inference(fr, P)
infer!(inf)
end

function return_type(ir::IR, args...)
fr = frame(copy(ir), args...)
inf = Inference(fr, Defaults())
infer!(inf)
return fr.rettype
fr = frame(copy(ir), args...)
inf = Inference(fr, Defaults())
infer!(inf)
return fr.rettype
end
3 changes: 1 addition & 2 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,4 @@ struct KwFunc{F} end

@abstract Basic Core.kwfunc(::T) where T = Const(KwFunc{T}())

instead(::Basic, args, ::AType{KwFunc{F}}, kw, f, xs...) where F =
args, (Core.kwftype(widen(f)), kw, f, xs...)
instead(::Basic, args, ::AType{KwFunc{F}}, kw, f, xs...) where F = (args, (Core.kwftype(widen(f)), kw, f, xs...))
4 changes: 2 additions & 2 deletions src/lib/numeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ struct Numeric end
asin, asind, asinh, atan, atand, atanh, cbrt, conj, cos, cosd, cosh,
cospi, cot, cotd, coth, csc, cscd, csch, deg2rad, exp, exp10, exp2,
expm1, float, inv, log, log10, log1p, log2, rad2deg, sec, secd, sech,
sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, trailing_zeros,
>>, <<, unsigned, rem, //
sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, trailing_zeros, >>, <<, unsigned, rem, //

@abstract Numeric rand() = Float64
@abstract Numeric randn() = Float64
@abstract Numeric rand(::Type{Bool}) where T = Bool

@abstract Numeric sum(xs::Array{T,N}; dims = :) where {T,N} =
dims == Const(:) ? T : Array{T,N}

9 changes: 9 additions & 0 deletions test/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,15 @@ tr = @trace pow(2, 3)
tr = @trace pow(::Int, ::Int)
@test returntype(tr) == Int

function m(n)
for i in 1 : n
nothing
end
end

tr = @trace m(::Int)
@test returntype(tr) == Const(nothing)

function sumabs2(xs)
s = zero(eltype(xs))
for i = 1:length(xs)
Expand Down