diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2ffbad6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +debug +Manifest.toml diff --git a/src/infer.jl b/src/infer.jl index 9edc3ba..e9f7223 100644 --- a/src/infer.jl +++ b/src/infer.jl @@ -182,70 +182,71 @@ function openbranches(bl) end function step!(inf::Inference) - frame, b, f, ip = pop!(inf.queue) - block, stmts = getblock(frame, b, f) - uninline!(frame, b, f) - if ip <= length(stmts) - var = stmts[ip] - st = block[var] - if isexpr(st.expr, :call) - T = infercall!(inf, (frame, b, f, ip), block, st.expr) - if T != Union{} - block.ir[var] = stmt(block[var], type = _union(st.type, T)) - push!(inf.queue, (frame, b, f, ip+1)) - end - elseif isexpr(st.expr, :inbounds) - push!(inf.queue, (frame, b, f, ip+1)) + frame, b, f, ip = pop!(inf.queue) + block, stmts = getblock(frame, b, f) + uninline!(frame, b, f) + if ip <= length(stmts) + var = stmts[ip] + st = block[var] + st.expr isa QuoteNode && return + if isexpr(st.expr, :call) + T = infercall!(inf, (frame, b, f, ip), block, st.expr) + if T != Union{} + block.ir[var] = stmt(block[var], type = _union(st.type, T)) + push!(inf.queue, (frame, b, f, ip+1)) + end + elseif isexpr(st.expr, :inbounds) + push!(inf.queue, (frame, b, f, ip+1)) + else + error("Unrecognised expression $(st.expr)") + end + elseif (brs = openbranches(block); length(brs) == 1 && !isreturn(brs[1]) + && !(brs[1].block == length(frame.ir.blocks))) + inferbranch!(inf, frame, b, f, brs[1]) else - error("Unrecognised expression $(st.expr)") - end - elseif (brs = openbranches(block); length(brs) == 1 && !isreturn(brs[1]) - && !(brs[1].block == length(frame.ir.blocks))) - inferbranch!(inf, frame, b, f, brs[1]) - else - for br in brs - if isreturn(br) - T = exprtype(block.ir, IRTools.returnvalue(block)) - _issubtype(T, frame.rettype) && return - frame.rettype = _union(frame.rettype, T) - foreach(loc -> push!(inf.queue, loc), frame.edges) - else - args = exprtype.((block.ir,), arguments(br)) - if blockargs!(IRTools.block(frame.ir, br.block), args) - push!(inf.queue, (frame, br.block, 0, 1)) + for br in brs + if isreturn(br) + T = exprtype(block.ir, IRTools.returnvalue(block)) + _issubtype(T, frame.rettype) && return + frame.rettype = _union(frame.rettype, T) + foreach(loc -> push!(inf.queue, loc), frame.edges) + else + args = exprtype.((block.ir,), arguments(br)) + if blockargs!(IRTools.block(frame.ir, br.block), args) + push!(inf.queue, (frame, br.block, 0, 1)) + end + end end - end end - end - return + return end function infer!(inf::Inference) - while !isempty(inf.queue) - step!(inf) - end - for (_, fr) in inf.frames - inline_bbs!(fr) - fr.ir |> IRTools.Inner.trimblocks! - end - return inf + while !isempty(inf.queue) + step!(inf) + end + for (_, fr) in inf.frames + inline_bbs!(fr) + fr.ir |> IRTools.Inner.trimblocks! + end + return inf end function Inference(fr::Frame, P) - q = WorkQueue{Any}() - push!(q, (fr, 1, 0, 1)) - Inference(Dict(argtypes(fr.ir)=>fr), q, IdDict(), P) + q = WorkQueue{Any}() + push!(q, (fr, 1, 0, 1)) + Inference(Dict(argtypes(fr.ir)=>fr), q, IdDict(), P) end function infer!(P, ir::IR, args...) - fr = frame(ir, args...) - inf = Inference(fr, P) - infer!(inf) + fr = frame(ir, args...) + inf = Inference(fr, P) + infer!(inf) end function return_type(ir::IR, args...) - fr = frame(copy(ir), args...) - inf = Inference(fr, Defaults()) - infer!(inf) - return fr.rettype + fr = frame(copy(ir), args...) + inf = Inference(fr, Defaults()) + infer!(inf) + return fr.rettype end diff --git a/src/lib/base.jl b/src/lib/base.jl index 57b714c..d179713 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -52,5 +52,4 @@ struct KwFunc{F} end @abstract Basic Core.kwfunc(::T) where T = Const(KwFunc{T}()) -instead(::Basic, args, ::AType{KwFunc{F}}, kw, f, xs...) where F = - args, (Core.kwftype(widen(f)), kw, f, xs...) +instead(::Basic, args, ::AType{KwFunc{F}}, kw, f, xs...) where F = (args, (Core.kwftype(widen(f)), kw, f, xs...)) diff --git a/src/lib/numeric.jl b/src/lib/numeric.jl index 84c1e22..80a4eec 100644 --- a/src/lib/numeric.jl +++ b/src/lib/numeric.jl @@ -7,8 +7,7 @@ struct Numeric end asin, asind, asinh, atan, atand, atanh, cbrt, conj, cos, cosd, cosh, cospi, cot, cotd, coth, csc, cscd, csch, deg2rad, exp, exp10, exp2, expm1, float, inv, log, log10, log1p, log2, rad2deg, sec, secd, sech, - sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, trailing_zeros, - >>, <<, unsigned, rem, // + sin, sind, sinh, sinpi, sqrt, tan, tand, tanh, transpose, trailing_zeros, >>, <<, unsigned, rem, // @abstract Numeric rand() = Float64 @abstract Numeric randn() = Float64 @@ -16,3 +15,4 @@ struct Numeric end @abstract Numeric sum(xs::Array{T,N}; dims = :) where {T,N} = dims == Const(:) ? T : Array{T,N} + diff --git a/test/trace.jl b/test/trace.jl index 85ef706..ce5d304 100644 --- a/test/trace.jl +++ b/test/trace.jl @@ -145,6 +145,15 @@ tr = @trace pow(2, 3) tr = @trace pow(::Int, ::Int) @test returntype(tr) == Int +function m(n) + for i in 1 : n + nothing + end +end + +tr = @trace m(::Int) +@test returntype(tr) == Const(nothing) + function sumabs2(xs) s = zero(eltype(xs)) for i = 1:length(xs)