Skip to content

Commit ba766f9

Browse files
authored
implement a system for interpreting overlay methods (#682)
* change the `recurse` interface to `AbstractInterpreter`-like interface Align the `recurse` argument to something like the base Compiler's `AbstractInterpreter` and make JuliaInterpreter routines overloadable properly. This change is quite breaking (thus bumping the minor version of this package), but necessary to enhance the customizability of JI. For example, it will make it easier to add changes like #682 in a nicer way, but also should enable better designs in packages such as Revise and JET. * implement a system for interpreting overlay methods Implemented, being inspired by #680, thinking it might be useful in the near future for JET's use case. The overlay interpretertation is enabled by overloading `JuliaInterpreter.method_table(recurse)`, but maybe the entire `recurse`-overload mechanism itself needs to be aligned with the `AbstractInterpreter` design and properly organized. Since this interface will also be needed for JET, I'll probably work on it soon.
1 parent 34cf432 commit ba766f9

File tree

7 files changed

+73
-17
lines changed

7 files changed

+73
-17
lines changed

docs/src/dev_reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ JuliaInterpreter.Interpreter
1313
JuliaInterpreter.RecursiveInterpreter
1414
JuliaInterpreter.NonRecursiveInterpreter
1515
JuliaInterpreter.Compiled
16+
JuliaInterpreter.method_table
1617
```
1718

1819
## Frame creation

src/construct.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ julia> argtypes
229229
Tuple{typeof(mymethod), Vector{Float64}}
230230
```
231231
"""
232-
function prepare_call(@nospecialize(f), allargs; enter_generated = false)
232+
function prepare_call(@nospecialize(f), allargs;
233+
enter_generated::Bool=false,
234+
method_table::Union{Nothing,MethodTable}=nothing)
233235
# Can happen for thunks created by generated functions
234236
if isa(f, Core.Builtin) || isa(f, Core.IntrinsicFunction)
235237
return nothing
@@ -247,13 +249,13 @@ function prepare_call(@nospecialize(f), allargs; enter_generated = false)
247249
return nothing
248250
end
249251
else
250-
method = whichtt(argtypes)
252+
method = whichtt(argtypes, method_table)
251253
end
252254
if method === nothing
253255
# Call it to generate the exact error
254256
return f(allargs[2:end]...)
255257
end
256-
ret = prepare_framecode(method, argtypes; enter_generated=enter_generated)
258+
ret = prepare_framecode(method, argtypes; enter_generated)
257259
# Exceptional returns
258260
if ret === nothing
259261
# The generator threw an error. Let's generate the same error by calling it.
@@ -580,7 +582,9 @@ Prepare all the information needed to execute a particular `:call` expression `e
580582
For example, try `JuliaInterpreter.determine_method_for_expr(:(\$sum([1,2])))`.
581583
See [`JuliaInterpreter.prepare_call`](@ref) for information about the outputs.
582584
"""
583-
function determine_method_for_expr(expr; enter_generated = false)
585+
function determine_method_for_expr(expr::Expr;
586+
enter_generated::Bool=false,
587+
method_table::Union{Nothing,MethodTable}=nothing)
584588
f = to_function(expr.args[1])
585589
allargs = expr.args
586590
# Extract keyword args
@@ -589,7 +593,7 @@ function determine_method_for_expr(expr; enter_generated = false)
589593
kwargs = splice!(allargs, 2)::Expr
590594
end
591595
f, allargs = prepare_args(f, allargs, kwargs.args)
592-
return prepare_call(f, allargs; enter_generated=enter_generated)
596+
return prepare_call(f, allargs; enter_generated, method_table)
593597
end
594598

595599
"""
@@ -626,9 +630,11 @@ T = Float64
626630
627631
See [`enter_call`](@ref) for a similar approach not based on expressions.
628632
"""
629-
function enter_call_expr(expr; enter_generated = false)
633+
function enter_call_expr(expr::Expr;
634+
enter_generated::Bool=false,
635+
method_table::Union{Nothing,MethodTable}=nothing)
630636
clear_caches()
631-
r = determine_method_for_expr(expr; enter_generated = enter_generated)
637+
r = determine_method_for_expr(expr; enter_generated, method_table)
632638
if r !== nothing && !isa(r[1], Compiled)
633639
return prepare_frame(Base.front(r)...)
634640
end
@@ -666,7 +672,9 @@ would be created by the generator.
666672
667673
See [`enter_call_expr`](@ref) for a similar approach based on expressions.
668674
"""
669-
function enter_call(@nospecialize(finfo), @nospecialize(args...); kwargs...)
675+
function enter_call(@nospecialize(finfo), @nospecialize(args...);
676+
method_table::Union{Nothing,MethodTable}=nothing,
677+
kwargs...)
670678
clear_caches()
671679
if isa(finfo, Tuple)
672680
f = finfo[1]
@@ -680,7 +688,7 @@ function enter_call(@nospecialize(finfo), @nospecialize(args...); kwargs...)
680688
if isa(f, Core.Builtin) || isa(f, Core.IntrinsicFunction)
681689
error(f, " is a builtin or intrinsic")
682690
end
683-
r = prepare_call(f, allargs; enter_generated=enter_generated)
691+
r = prepare_call(f, allargs; enter_generated, method_table)
684692
if r !== nothing && !isa(r[1], Compiled)
685693
return prepare_frame(Base.front(r)...)
686694
end

src/interpret.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,9 @@ function evaluate_call!(interp::Interpreter, frame::Frame, call_expr::Expr, ente
260260
lenv === nothing && return framecode # this was a Builtin
261261
fargs = fargs_pruned
262262
else
263-
framecode, lenv = get_call_framecode(fargs, frame.framecode, frame.pc; enter_generated=enter_generated)
263+
method_table = JuliaInterpreter.method_table(interp)
264+
framecode, lenv = get_call_framecode(fargs, frame.framecode, frame.pc;
265+
enter_generated, method_table)
264266
if lenv === nothing
265267
if isa(framecode, Compiled)
266268
return native_call(fargs, frame)

src/localmethtable.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ Return the framecode and environment for a call specified by `fargs = [f, args..
77
`parentframecode` is the caller, and `idx` is the program-counter index.
88
If possible, `framecode` will be looked up from the local method tables of `parentframe`.
99
"""
10-
function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int; enter_generated::Bool=false)
10+
function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int;
11+
enter_generated::Bool=false,
12+
method_table::Union{Nothing,MethodTable}=nothing)
1113
nargs = length(fargs) # includes f as the first "argument"
1214
# Determine whether we can look up the appropriate framecode in the local method table
1315
if isassigned(parentframe.methodtables, idx) # if this is the first call, this may not yet be set
@@ -60,7 +62,7 @@ function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int
6062
end
6163
# We haven't yet encountered this argtype combination and need to look it up by dispatch
6264
fargs[1] = f = to_function(fargs[1])
63-
ret = prepare_call(f, fargs; enter_generated=enter_generated)
65+
ret = prepare_call(f, fargs; enter_generated, method_table)
6466
ret === nothing && return invokelatest(f, fargs[2:end]...), nothing
6567
is_compiled = isa(ret[1], Compiled)
6668
local framecode

src/types.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ redefined as a completely different type in v0.11 or later.
3939
const Compiled = NonRecursiveInterpreter # for backward compatibility
4040
Base.similar(::Compiled, sz) = Compiled() # to support similar(stack, 0)
4141

42+
"""
43+
method_table(interpreter::Interpreter) -> mt::Union{Nothing,MethodTable}
44+
45+
Configures the method table used for method lookups performed by the interpreter.
46+
Uses the global method table by default.
47+
"""
48+
method_table(::Interpreter) = nothing
49+
4250
# Our own replacements for Core types. We need to do this to ensure we can tell the difference
4351
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
4452
struct SSAValue

src/utils.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,29 @@ end
3131
Like `which` except it operates on the complete tuple-type `tt`,
3232
and doesn't throw when there is no matching method.
3333
"""
34-
function whichtt(@nospecialize(tt), mt::Union{Nothing, MethodTable} = nothing)
34+
function whichtt(@nospecialize(tt), mt::Union{Nothing,MethodTable}=nothing)
3535
# TODO: provide explicit control over world age? In case we ever need to call "old" methods.
36-
# branch on https://github.com/JuliaLang/julia/pull/44515
37-
# for now, code execution doesn't have the capability to use an overlayed method table,
38-
# which is meant to be addressed in https://github.com/JuliaDebug/JuliaInterpreter.jl/pull/682.
39-
match, _ = Core.Compiler._findsup(tt, mt, get_world_counter())
36+
# TODO Use `CachedMethodTable` for better performance once `teh/worldage` is merged
37+
match, _ = findsup_mt(tt, Base.get_world_counter(), mt)
4038
match === nothing && return nothing
4139
return match.method
4240
end
4341

42+
@static if VERSION v"1.12-"
43+
using Base.Compiler: findsup_mt
44+
else
45+
function findsup_mt(@nospecialize(tt), world, method_table)
46+
if method_table === nothing
47+
table = Core.Compiler.InternalMethodTable(world)
48+
elseif method_table isa Core.MethodTable
49+
table = Core.Compiler.OverlayMethodTable(world, method_table)
50+
else
51+
table = method_table
52+
end
53+
return Core.Compiler.findsup(tt, table)
54+
end
55+
end
56+
4457
instantiate_type_in_env(arg, spsig::UnionAll, spvals::Vector{Any}) =
4558
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), arg, spsig, spvals)
4659

test/interpret.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,3 +975,25 @@ end
975975
@test_throws "Invalid @interpret call" macroexpand(@__MODULE__, :(@interpret interp sin(42)))
976976
@test_throws "Invalid @interpret call" macroexpand(@__MODULE__, :(@interpret _interp_=RecursiveInterpreter() sin(42)))
977977
end
978+
979+
function func_overlay end
980+
func_overlay(x) = sin(x)
981+
call_func_overlay(x) = func_overlay(x)
982+
Base.Experimental.@MethodTable ex_method_table
983+
Base.Experimental.@overlay ex_method_table func_overlay(x) = cos(x)
984+
struct OverlayInterpreter <: Interpreter end
985+
JuliaInterpreter.method_table(::OverlayInterpreter) = ex_method_table
986+
987+
@testset "Interpret overlay method" begin
988+
let frame = JuliaInterpreter.Frame(@__MODULE__, :(func_overlay(42.0)))
989+
@test JuliaInterpreter.finish_and_return!(frame, true) == sin(42.0)
990+
end
991+
@test sin(42.0) == @interpret call_func_overlay(42.0)
992+
let frame = JuliaInterpreter.Frame(@__MODULE__, :(func_overlay(42.0)))
993+
@test JuliaInterpreter.finish_and_return!(OverlayInterpreter(), frame, true) == cos(42.0)
994+
end
995+
let frame = JuliaInterpreter.enter_call(func_overlay, 42.0; method_table=ex_method_table)
996+
@test JuliaInterpreter.finish_and_return!(OverlayInterpreter(), frame) == cos(42.0)
997+
end
998+
@test cos(42.0) == @interpret interp=OverlayInterpreter() call_func_overlay(42.0)
999+
end

0 commit comments

Comments
 (0)