Skip to content

Commit 1238d1a

Browse files
committed
implement a system for interpreting overlay methods
Implemented, being inspired by #680, thinking it might be useful in the near future for JET's use case. The overlay interpretertation is enabled by overloading `JuliaInterpreter.method_table(recurse)`, but maybe the entire `recurse`-overload mechanism itself needs to be aligned with the `AbstractInterpreter` design and properly organized. Since this interface will also be needed for JET, I'll probably work on it soon.
1 parent be1fb4f commit 1238d1a

File tree

6 files changed

+72
-17
lines changed

6 files changed

+72
-17
lines changed

src/construct.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ julia> argtypes
229229
Tuple{typeof(mymethod), Vector{Float64}}
230230
```
231231
"""
232-
function prepare_call(@nospecialize(f), allargs; enter_generated = false)
232+
function prepare_call(@nospecialize(f), allargs;
233+
enter_generated::Bool=false,
234+
method_table::Union{Nothing,MethodTable}=nothing)
233235
# Can happen for thunks created by generated functions
234236
if isa(f, Core.Builtin) || isa(f, Core.IntrinsicFunction)
235237
return nothing
@@ -247,13 +249,13 @@ function prepare_call(@nospecialize(f), allargs; enter_generated = false)
247249
return nothing
248250
end
249251
else
250-
method = whichtt(argtypes)
252+
method = whichtt(argtypes, method_table)
251253
end
252254
if method === nothing
253255
# Call it to generate the exact error
254256
return f(allargs[2:end]...)
255257
end
256-
ret = prepare_framecode(method, argtypes; enter_generated=enter_generated)
258+
ret = prepare_framecode(method, argtypes; enter_generated)
257259
# Exceptional returns
258260
if ret === nothing
259261
# The generator threw an error. Let's generate the same error by calling it.
@@ -580,7 +582,9 @@ Prepare all the information needed to execute a particular `:call` expression `e
580582
For example, try `JuliaInterpreter.determine_method_for_expr(:(\$sum([1,2])))`.
581583
See [`JuliaInterpreter.prepare_call`](@ref) for information about the outputs.
582584
"""
583-
function determine_method_for_expr(expr; enter_generated = false)
585+
function determine_method_for_expr(expr::Expr;
586+
enter_generated::Bool=false,
587+
method_table::Union{Nothing,MethodTable}=nothing)
584588
f = to_function(expr.args[1])
585589
allargs = expr.args
586590
# Extract keyword args
@@ -589,7 +593,7 @@ function determine_method_for_expr(expr; enter_generated = false)
589593
kwargs = splice!(allargs, 2)::Expr
590594
end
591595
f, allargs = prepare_args(f, allargs, kwargs.args)
592-
return prepare_call(f, allargs; enter_generated=enter_generated)
596+
return prepare_call(f, allargs; enter_generated, method_table)
593597
end
594598

595599
"""
@@ -626,9 +630,11 @@ T = Float64
626630
627631
See [`enter_call`](@ref) for a similar approach not based on expressions.
628632
"""
629-
function enter_call_expr(expr; enter_generated = false)
633+
function enter_call_expr(expr::Expr;
634+
enter_generated::Bool=false,
635+
method_table::Union{Nothing,MethodTable}=nothing)
630636
clear_caches()
631-
r = determine_method_for_expr(expr; enter_generated = enter_generated)
637+
r = determine_method_for_expr(expr; enter_generated, method_table)
632638
if r !== nothing && !isa(r[1], Compiled)
633639
return prepare_frame(Base.front(r)...)
634640
end
@@ -666,7 +672,9 @@ would be created by the generator.
666672
667673
See [`enter_call_expr`](@ref) for a similar approach based on expressions.
668674
"""
669-
function enter_call(@nospecialize(finfo), @nospecialize(args...); kwargs...)
675+
function enter_call(@nospecialize(finfo), @nospecialize(args...);
676+
method_table::Union{Nothing,MethodTable}=nothing,
677+
kwargs...)
670678
clear_caches()
671679
if isa(finfo, Tuple)
672680
f = finfo[1]
@@ -680,7 +688,7 @@ function enter_call(@nospecialize(finfo), @nospecialize(args...); kwargs...)
680688
if isa(f, Core.Builtin) || isa(f, Core.IntrinsicFunction)
681689
error(f, " is a builtin or intrinsic")
682690
end
683-
r = prepare_call(f, allargs; enter_generated=enter_generated)
691+
r = prepare_call(f, allargs; enter_generated, method_table)
684692
if r !== nothing && !isa(r[1], Compiled)
685693
return prepare_frame(Base.front(r)...)
686694
end

src/interpret.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,9 @@ function evaluate_call!(interp::Interpreter, frame::Frame, call_expr::Expr, ente
260260
lenv === nothing && return framecode # this was a Builtin
261261
fargs = fargs_pruned
262262
else
263-
framecode, lenv = get_call_framecode(fargs, frame.framecode, frame.pc; enter_generated=enter_generated)
263+
method_table = JuliaInterpreter.method_table(interp)
264+
framecode, lenv = get_call_framecode(fargs, frame.framecode, frame.pc;
265+
enter_generated, method_table)
264266
if lenv === nothing
265267
if isa(framecode, Compiled)
266268
return native_call(fargs, frame)

src/localmethtable.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ Return the framecode and environment for a call specified by `fargs = [f, args..
77
`parentframecode` is the caller, and `idx` is the program-counter index.
88
If possible, `framecode` will be looked up from the local method tables of `parentframe`.
99
"""
10-
function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int; enter_generated::Bool=false)
10+
function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int;
11+
enter_generated::Bool=false,
12+
method_table::Union{Nothing,MethodTable}=nothing)
1113
nargs = length(fargs) # includes f as the first "argument"
1214
# Determine whether we can look up the appropriate framecode in the local method table
1315
if isassigned(parentframe.methodtables, idx) # if this is the first call, this may not yet be set
@@ -60,7 +62,7 @@ function get_call_framecode(fargs::Vector{Any}, parentframe::FrameCode, idx::Int
6062
end
6163
# We haven't yet encountered this argtype combination and need to look it up by dispatch
6264
fargs[1] = f = to_function(fargs[1])
63-
ret = prepare_call(f, fargs; enter_generated=enter_generated)
65+
ret = prepare_call(f, fargs; enter_generated, method_table)
6466
ret === nothing && return invokelatest(f, fargs[2:end]...), nothing
6567
is_compiled = isa(ret[1], Compiled)
6668
local framecode

src/types.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ redefined as a completely different type in v0.11 or later.
3939
const Compiled = NonRecursiveInterpreter # for backward compatibility
4040
Base.similar(::Compiled, sz) = Compiled() # to support similar(stack, 0)
4141

42+
"""
43+
method_table(interpreter::Interpreter) -> mt::Union{Nothing,MethodTable}
44+
45+
Configures the method table used for method lookups performed by the interpreter.
46+
Uses the global method table by default.
47+
"""
48+
method_table(::Interpreter) = nothing
49+
4250
# Our own replacements for Core types. We need to do this to ensure we can tell the difference
4351
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
4452
struct SSAValue

src/utils.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,29 @@ end
3131
Like `which` except it operates on the complete tuple-type `tt`,
3232
and doesn't throw when there is no matching method.
3333
"""
34-
function whichtt(@nospecialize(tt), mt::Union{Nothing, MethodTable} = nothing)
34+
function whichtt(@nospecialize(tt), mt::Union{Nothing,MethodTable}=nothing)
3535
# TODO: provide explicit control over world age? In case we ever need to call "old" methods.
36-
# branch on https://github.com/JuliaLang/julia/pull/44515
37-
# for now, code execution doesn't have the capability to use an overlayed method table,
38-
# which is meant to be addressed in https://github.com/JuliaDebug/JuliaInterpreter.jl/pull/682.
39-
match, _ = Core.Compiler._findsup(tt, mt, get_world_counter())
36+
# TODO Use `CachedMethodTable` for better performance once `teh/worldage` is merged
37+
match, _ = findsup_mt(tt, Base.get_world_counter(), mt)
4038
match === nothing && return nothing
4139
return match.method
4240
end
4341

42+
@static if VERSION v"1.12-"
43+
using Base.Compiler: findsup_mt
44+
else
45+
function findsup_mt(@nospecialize(tt), world, method_table)
46+
if method_table === nothing
47+
table = Core.Compiler.InternalMethodTable(world)
48+
elseif method_table isa Core.MethodTable
49+
table = Core.Compiler.OverlayMethodTable(world, method_table)
50+
else
51+
table = method_table
52+
end
53+
return Core.Compiler.findsup(tt, table)
54+
end
55+
end
56+
4457
instantiate_type_in_env(arg, spsig::UnionAll, spvals::Vector{Any}) =
4558
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), arg, spsig, spvals)
4659

test/interpret.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,3 +975,25 @@ end
975975
@test_throws "Invalid @interpret call" macroexpand(@__MODULE__, :(@interpret interp sin(42)))
976976
@test_throws "Invalid @interpret call" macroexpand(@__MODULE__, :(@interpret _interp_=RecursiveInterpreter() sin(42)))
977977
end
978+
979+
function func_overlay end
980+
func_overlay(x) = sin(x)
981+
call_func_overlay(x) = func_overlay(x)
982+
Base.Experimental.@MethodTable ex_method_table
983+
Base.Experimental.@overlay ex_method_table func_overlay(x) = cos(x)
984+
struct OverlayInterpreter <: Interpreter end
985+
JuliaInterpreter.method_table(::OverlayInterpreter) = ex_method_table
986+
987+
@testset "Interpret overlay method" begin
988+
let frame = JuliaInterpreter.Frame(@__MODULE__, :(func_overlay(42.0)))
989+
@test JuliaInterpreter.finish_and_return!(frame, true) == sin(42.0)
990+
end
991+
@test sin(42.0) == @interpret call_func_overlay(42.0)
992+
let frame = JuliaInterpreter.Frame(@__MODULE__, :(func_overlay(42.0)))
993+
@test JuliaInterpreter.finish_and_return!(OverlayInterpreter(), frame, true) == cos(42.0)
994+
end
995+
let frame = JuliaInterpreter.enter_call(func_overlay, 42.0; method_table=ex_method_table)
996+
@test JuliaInterpreter.finish_and_return!(OverlayInterpreter(), frame) == cos(42.0)
997+
end
998+
@test cos(42.0) == @interpret interp=OverlayInterpreter() call_func_overlay(42.0)
999+
end

0 commit comments

Comments
 (0)