Skip to content

Commit 04689bb

Browse files
committed
enable toplevel optimization and eliminate some toplevel special casings
Better to work with <JuliaLang/julia#42013>, but I also added an hacky fallback that makes use of the existing method definition pipeline.
1 parent 6591201 commit 04689bb

File tree

6 files changed

+50
-49
lines changed

6 files changed

+50
-49
lines changed

src/JET.jl

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -842,30 +842,32 @@ function report_text(text::AbstractString,
842842
return JETToplevelResult(analyzer′, res, source; analyzer, jetconfigs...)
843843
end
844844

845+
# we have to go on hacks; see `transform_abstract_global_symbols!` and `resolve_toplevel_symbols`
845846
function analyze_toplevel!(analyzer::AbstractAnalyzer, src::CodeInfo)
846847
# construct toplevel `MethodInstance`
847848
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ());
848-
mi.uninferred = src
849849
mi.specTypes = Tuple{}
850850

851-
transform_abstract_global_symbols!(analyzer, src)
852-
mi.def = get_toplevelmod(analyzer)
851+
mi.def = mod = get_toplevelmod(analyzer)
852+
src = transform_abstract_global_symbols!(analyzer, src)
853+
src = resolve_toplevel_symbols(mod, src)
854+
mi.uninferred = src
853855

854856
result = InferenceResult(mi);
855-
# toplevel frame doesn't need to be cached (and so it won't be optimized), nor should
856-
# go through JET's code generation error check
857-
frame = InferenceState(result, src, #=cached=# false, analyzer);
857+
# toplevel frames don't really need to be cached, but still better to be optimized
858+
# in order to get reasonable `LocalUndefVarErrorReport` and `UncaughtExceptionReport`
859+
frame = InferenceState(result, src, #=cached=# true, analyzer);
858860

859861
return analyze_frame!(analyzer, frame)
860862
end
861863

862-
# HACK this is an native hack to re-use `AbstractInterpreter`'s approximated slot types for
864+
# HACK this is very naive hack to re-use `AbstractInterpreter`'s slot type approximation for
863865
# assignments of abstract global variables, which are represented as toplevel symbols at this point;
864-
# the idea is just to transform them into slots from symbols and use their approximated type
865-
# on their assignment.
866+
# the idea is just to transform them into slot from symbol and use their approximated type
867+
# on their assignment (see `finish(::InferenceState, ::AbstractAnalyzer)`).
866868
# NOTE that `transform_abstract_global_symbols!` will produce really invalid code for
867869
# actual interpretation or execution, but all the statements won't be interpreted anymore
868-
# by `ConcreteInterpreter` nor executed anyway since toplevel frames aren't cached
870+
# by `ConcreteInterpreter` nor executed by the native compilation pipeline anyway
869871
function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::CodeInfo)
870872
nslots = length(src.slotnames)
871873
abstrct_global_variables = Dict{Symbol,Int}()
@@ -901,6 +903,35 @@ function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::Cod
901903
end
902904

903905
set_global_slots!(analyzer, Dict(idx => slotname for (slotname, idx) in abstrct_global_variables))
906+
907+
return src
908+
end
909+
910+
# resolve toplevel symbols (and other expressions like `:foreigncall`)
911+
# so that the returned `CodeInfo` is eligible for abstractintepret and optimization
912+
@static if VERSION v"1.8.0-DEV.420"
913+
function resolve_toplevel_symbols(mod::Module, src::CodeInfo)
914+
newsrc = copy(src)
915+
@ccall jl_resolve_globals_in_ir(newsrc.code::Any, mod::Any, svec()::Any, 1::Any)::Cvoid
916+
return newsrc
917+
end
918+
else
919+
# HACK before https://github.com/JuliaLang/julia/pull/42013, we need to go through
920+
# the method definition pipeline to get the effect of `jl_resolve_globals_in_ir`
921+
function resolve_toplevel_symbols(mod::Module, src::CodeInfo)
922+
sig = Core.svec(
923+
svec(typeof(__toplevelf__)),
924+
svec(),
925+
QuoteNode(LineNumberNode(@__LINE__, @__FILE__)))
926+
# branching on https://github.com/JuliaLang/julia/pull/41137
927+
method = (@static if isdefined(Core.Compiler, :OverlayMethodTable)
928+
ccall(:jl_method_def, Any, (Any, Ptr{Cvoid}, Any, Any), sig, C_NULL, body, moduleof(frame))
929+
else
930+
ccall(:jl_method_def, Any, (Any, Any, Any), sig, body, moduleof(frame))
931+
end)::Method
932+
return CC.uncompressed_ir(method)
933+
end
934+
function __toplevelf__ end
904935
end
905936

906937
# TODO `analyze_builtin!` ?

src/abstractinterpretation.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,6 @@ function CC.abstract_eval_special_value(analyzer::AbstractAnalyzer, @nospecializ
385385
# if it's really not defined, the error will be generated later anyway
386386
e = GlobalRef(get_toplevelmod(analyzer), get_slotname(sv, e))
387387
end
388-
elseif isa(e, Symbol)
389-
# (already concretized) toplevel global symbols
390-
e = GlobalRef(get_toplevelmod(analyzer), e)
391388
end
392389
end
393390

@@ -749,7 +746,7 @@ function is_constant_declared(name::Symbol, sv::InferenceState)
749746
return any(sv.src.code) do @nospecialize(x)
750747
if @isexpr(x, :const)
751748
arg = first(x.args)
752-
# `transform_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s
749+
# `transform_abstract_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s
753750
if isa(arg, Slot)
754751
return get_slotname(sv, arg) === name
755752
end

src/analyzer.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,7 @@ function maybe_initialize_caches!(analyzer::AbstractAnalyzer)
482482
end
483483

484484
# check if we're in a toplevel module
485-
@inline istoplevel(sv::InferenceState) = istoplevel(sv.linfo)
486-
@inline istoplevel(::OptimizationState) = false # optimization never happen for top-level code
485+
@inline istoplevel(sv::State) = istoplevel(sv.linfo)
487486
@inline istoplevel(linfo::MethodInstance) = isa(linfo.def, Module)
488487

489488
is_global_slot(analyzer::AbstractAnalyzer, slot::Int) = slot in keys(get_global_slots(analyzer))

src/locinfo.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,7 @@ function _get_sig_type((sv, _)::StateAtPC, arg::Argument)
160160
return Any[sig, typ], typ
161161
end
162162
_get_sig_type(_::StateAtPC, gr::GlobalRef) = Any[string(gr.mod, '.', gr.name)], nothing
163-
function _get_sig_type(s::StateAtPC, name::Symbol)
164-
sv = first(s)
165-
if istoplevel(sv)
166-
# this is concrete global variable, form the global reference
167-
return _get_sig_type(s, GlobalRef(sv.linfo.def, name))
168-
else
169-
return Any[repr(name; context = :compact => true)], nothing
170-
end
171-
end
163+
_get_sig_type(_::StateAtPC, name::Symbol) = Any[repr(name; context = :compact => true)], nothing
172164
function _get_sig_type(s::StateAtPC, gotoifnot::GotoIfNot)
173165
sig = Any[string("goto %", gotoifnot.dest, " if not "), _get_sig(s, gotoifnot.cond)...]
174166
return sig, nothing

src/typeinfer.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA
287287
throw_locs = get_throw_locs(analyzer)
288288
throw_calls = Tuple{Int,Expr}[]
289289
for (pc, stmt) in enumerate(stmts)
290-
is_throw_call_expr(analyzer, frame, stmt) || continue
290+
isa(stmt, Expr) || continue
291+
is_throw_call(stmt) || continue
291292
# if this `throw` is already reported, don't duplciate
292293
linetable[codelocs[pc]]::LineInfoNode in throw_locs && continue
293294
push!(throw_calls, (pc, stmt))
@@ -303,22 +304,3 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA
303304
empty!(get_uncaught_exceptions(analyzer))
304305
end
305306
end
306-
307-
# basically same as `is_throw_call`, but also toplevel module handling added
308-
function is_throw_call_expr(analyzer::AbstractAnalyzer, frame::InferenceState, @nospecialize(e))
309-
if isa(e, Expr)
310-
if e.head === :call
311-
f = e.args[1]
312-
if istoplevel(frame) && isa(f, Symbol)
313-
f = GlobalRef(get_toplevelmod(analyzer), f)
314-
end
315-
if isa(f, GlobalRef)
316-
ff = CC.abstract_eval_global(f.mod, f.name)
317-
if isa(ff, Const) && ff.val === Core.throw
318-
return true
319-
end
320-
end
321-
end
322-
end
323-
return false
324-
end

test/test_abstractinterpretation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ end
134134
@test isempty(get_reports(analyzer))
135135
end
136136

137-
# with the current approach, local undefined variables in toplevel frame can't be found
138-
# since we don't cache toplevel frame and thus it won't be optimized
139-
let
137+
let # should work for top-level analysis
140138
res = @analyze_toplevel begin
141139
foo = let
142140
if rand(Bool)
@@ -146,7 +144,9 @@ end
146144
end
147145
end
148146
end
149-
@test_broken !isempty(res.inference_error_reports)
147+
@test length(res.inference_error_reports) === 1 &&
148+
first(res.inference_error_reports) isa LocalUndefVarErrorReport &&
149+
first(res.inference_error_reports).name === :bar
150150
end
151151
end
152152

0 commit comments

Comments
 (0)