Skip to content

Commit 9fc33b9

Browse files
authored
Support "functors" for code reflection utilities (#58891)
As a follow-up to #57911, this updates: - `Base.method_instance` - `Base.method_instances` - `Base.code_ircode` - `Base.code_lowered` - `InteractiveUtils.code_llvm` - `InteractiveUtils.code_native` - `InteractiveUtils.code_warntype` to support "functor" invocations. e.g. `code_llvm((Foo, Int, Int))` which corresponds to `(::Foo)(::Int, ::Int)`
1 parent 144de95 commit 9fc33b9

File tree

6 files changed

+125
-38
lines changed

6 files changed

+125
-38
lines changed

Compiler/test/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
# The tests below assume a certain format and safepoint_on_entry=true breaks that.
2323
function get_llvm(@nospecialize(f), @nospecialize(t), raw=true, dump_module=false, optimize=true)
2424
params = Base.CodegenParams(safepoint_on_entry=false, gcstack_arg = false, debug_info_level=Cint(2))
25-
d = InteractiveUtils._dump_function(f, t, false, false, raw, dump_module, :att, optimize, :none, false, params)
25+
d = InteractiveUtils._dump_function(InteractiveUtils.ArgInfo(f, t), false, false, raw, dump_module, :att, optimize, :none, false, params)
2626
sprint(print, d)
2727
end
2828

base/reflection.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The keyword `debuginfo` controls the amount of code metadata present in the outp
1616
Note that an error will be thrown if `types` are not concrete types when `generated` is
1717
`true` and any of the corresponding methods are an `@generated` method.
1818
"""
19-
function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=true, debuginfo::Symbol=:default)
19+
function code_lowered(@nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}); generated::Bool=true, debuginfo::Symbol=:default)
2020
if @isdefined(IRShow)
2121
debuginfo = IRShow.debuginfo(debuginfo)
2222
elseif debuginfo === :default
@@ -28,7 +28,7 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
2828
world = get_world_counter()
2929
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
3030
ret = CodeInfo[]
31-
for m in method_instances(f, t, world)
31+
for m in method_instances(argtypes, world)
3232
if generated && hasgenerator(m)
3333
if may_invoke_generator(m)
3434
code = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)
@@ -46,12 +46,17 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
4646
return ret
4747
end
4848

49+
function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=true, debuginfo::Symbol=:default)
50+
tt = signature_type(f, t)
51+
return code_lowered(tt; generated, debuginfo)
52+
end
53+
4954
# for backwards compat
5055
const uncompressed_ast = uncompressed_ir
5156
const _uncompressed_ast = _uncompressed_ir
5257

53-
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
54-
tt = signature_type(f, t)
58+
function method_instances(@nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}), world::UInt)
59+
tt = to_tuple_type(argtypes)
5560
results = Core.MethodInstance[]
5661
# this make a better error message than the typeassert that follows
5762
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
@@ -62,15 +67,26 @@ function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
6267
return results
6368
end
6469

65-
function method_instance(@nospecialize(f), @nospecialize(t);
66-
world=Base.get_world_counter(), method_table=nothing)
70+
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
6771
tt = signature_type(f, t)
72+
return method_instances(tt, world)
73+
end
74+
75+
function method_instance(@nospecialize(argtypes::Union{Tuple,Type{<:Tuple}});
76+
world=Base.get_world_counter(), method_table=nothing)
77+
tt = to_tuple_type(argtypes)
6878
mi = ccall(:jl_method_lookup_by_tt, Any,
6979
(Any, Csize_t, Any),
7080
tt, world, method_table)
7181
return mi::Union{Nothing, MethodInstance}
7282
end
7383

84+
function method_instance(@nospecialize(f), @nospecialize(t);
85+
world=Base.get_world_counter(), method_table=nothing)
86+
tt = signature_type(f, t)
87+
return method_instance(tt; world, method_table)
88+
end
89+
7490
default_debug_info_kind() = unsafe_load(cglobal(:jl_default_debug_info_kind, Cint))
7591

7692
# this type mirrors jl_cgparams_t (documented in julia.h)
@@ -431,6 +447,11 @@ function code_ircode(@nospecialize(f), @nospecialize(types = default_tt(f)); kwa
431447
return code_ircode_by_type(tt; kwargs...)
432448
end
433449

450+
function code_ircode(@nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}); kwargs...)
451+
tt = to_tuple_type(argtypes)
452+
return code_ircode_by_type(tt; kwargs...)
453+
end
454+
434455
"""
435456
code_ircode_by_type(types::Type{<:Tuple}; ...)
436457

stdlib/InteractiveUtils/src/codeview.jl

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ const llstyle = Dict{Symbol, Tuple{Bool, Union{Symbol, Int}}}(
2020
:funcname => (false, :light_yellow),
2121
)
2222

23+
struct ArgInfo
24+
oc::Union{Core.OpaqueClosure,Nothing}
25+
tt::Type{<:Tuple}
26+
27+
# Construct from a function object + argtypes
28+
function ArgInfo(@nospecialize(f), @nospecialize(t))
29+
if isa(f, Core.Builtin)
30+
throw(ArgumentError("argument is not a generic function"))
31+
elseif f isa Core.OpaqueClosure
32+
return new(f, Base.to_tuple_type(t))
33+
else
34+
return new(nothing, signature_type(f, t))
35+
end
36+
end
37+
38+
# Construct from argtypes (incl. arg0)
39+
function ArgInfo(@nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}))
40+
tt = Base.to_tuple_type(argtypes)
41+
return new(nothing, tt)
42+
end
43+
end
44+
2345
function printstyled_ll(io::IO, x, s::Symbol, trailing_spaces="")
2446
printstyled(io, x, bold=llstyle[s][1], color=llstyle[s][2])
2547
print(io, trailing_spaces)
@@ -143,7 +165,7 @@ See the [`@code_warntype`](@ref man-code-warntype) section in the Performance Ti
143165
144166
See also: [`@code_warntype`](@ref), [`code_typed`](@ref), [`code_lowered`](@ref), [`code_llvm`](@ref), [`code_native`](@ref).
145167
"""
146-
function code_warntype(io::IO, @nospecialize(f), @nospecialize(tt=Base.default_tt(f));
168+
function code_warntype(io::IO, arginfo::ArgInfo;
147169
world=Base.get_world_counter(),
148170
interp::Base.Compiler.AbstractInterpreter=Base.Compiler.NativeInterpreter(world),
149171
debuginfo::Symbol=:default, optimize::Bool=false, kwargs...)
@@ -152,13 +174,14 @@ function code_warntype(io::IO, @nospecialize(f), @nospecialize(tt=Base.default_t
152174
debuginfo = Base.IRShow.debuginfo(debuginfo)
153175
lineprinter = Base.IRShow.__debuginfo[debuginfo]
154176
nargs::Int = 0
155-
if isa(f, Core.OpaqueClosure)
156-
isa(f.source, Method) && (nargs = f.source.nargs)
157-
print_warntype_codeinfo(io, Base.code_typed_opaque_closure(f, tt)[1]..., nargs;
177+
if arginfo.oc !== nothing
178+
(; oc, tt) = arginfo
179+
isa(oc.source, Method) && (nargs = oc.source.nargs)
180+
print_warntype_codeinfo(io, Base.code_typed_opaque_closure(oc, tt)[1]..., nargs;
158181
lineprinter, label_dynamic_calls = optimize)
159182
return nothing
160183
end
161-
tt = Base.signature_type(f, tt)
184+
tt = arginfo.tt
162185
matches = findall(tt, Base.Compiler.method_table(interp))
163186
matches === nothing && Base.raise_match_failure(:code_warntype, tt)
164187
for match in matches.matches
@@ -176,6 +199,8 @@ function code_warntype(io::IO, @nospecialize(f), @nospecialize(tt=Base.default_t
176199
end
177200
nothing
178201
end
202+
code_warntype(io::IO, @nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...) = code_warntype(io, ArgInfo(f, tt); kwargs...)
203+
code_warntype(io::IO, @nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}); kwargs...) = code_warntype(io, ArgInfo(argtypes); kwargs...)
179204
code_warntype(args...; kwargs...) = (@nospecialize; code_warntype(stdout, args...; kwargs...))
180205

181206
using Base: CodegenParams
@@ -189,33 +214,30 @@ const OC_MISMATCH_WARNING =
189214

190215
# Printing code representations in IR and assembly
191216

192-
function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrapper::Bool,
217+
function _dump_function(arginfo::ArgInfo, native::Bool, wrapper::Bool,
193218
raw::Bool, dump_module::Bool, syntax::Symbol,
194219
optimize::Bool, debuginfo::Symbol, binary::Bool,
195220
params::CodegenParams=CodegenParams(debug_info_kind=Cint(0), debug_info_level=Cint(2), safepoint_on_entry=raw, gcstack_arg=raw))
196221
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
197-
if isa(f, Core.Builtin)
198-
throw(ArgumentError("argument is not a generic function"))
199-
end
200222
warning = ""
201223
# get the MethodInstance for the method match
202-
if !isa(f, Core.OpaqueClosure)
224+
if arginfo.oc === nothing
203225
world = Base.get_world_counter()
204-
match = Base._which(signature_type(f, t); world)
226+
match = Base._which(arginfo.tt; world)
205227
mi = Base.specialize_method(match)
206228
# TODO: use jl_is_cacheable_sig instead of isdispatchtuple
207229
isdispatchtuple(mi.specTypes) || (warning = GENERIC_SIG_WARNING)
208230
else
209-
world = UInt64(f.world)
210-
tt = Base.to_tuple_type(t)
211-
if !isdefined(f.source, :source)
231+
(; oc, tt) = arginfo
232+
world = UInt64(oc.world)
233+
if !isdefined(oc.source, :source)
212234
# OC was constructed from inferred source. There's only one
213235
# specialization and we can't infer anything more precise either.
214-
world = f.source.primary_world
215-
mi = f.source.specializations::Core.MethodInstance
216-
Base.hasintersect(typeof(f).parameters[1], tt) || (warning = OC_MISMATCH_WARNING)
236+
world = oc.source.primary_world
237+
mi = oc.source.specializations::Core.MethodInstance
238+
Base.hasintersect(typeof(oc).parameters[1], tt) || (warning = OC_MISMATCH_WARNING)
217239
else
218-
mi = Base.specialize_method(f.source, Tuple{typeof(f.captures), tt.parameters...}, Core.svec())
240+
mi = Base.specialize_method(oc.source, Tuple{typeof(oc.captures), tt.parameters...}, Core.svec())
219241
isdispatchtuple(mi.specTypes) || (warning = GENERIC_SIG_WARNING)
220242
end
221243
end
@@ -236,19 +258,19 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
236258
end
237259
if isempty(str)
238260
# if that failed (or we want metadata), use LLVM to generate more accurate assembly output
239-
if !isa(f, Core.OpaqueClosure)
261+
if arginfo.oc === nothing
240262
src = Base.Compiler.typeinf_code(Base.Compiler.NativeInterpreter(world), mi, true)
241263
else
242-
src, rt = Base.get_oc_code_rt(nothing, f, tt, true)
264+
src, rt = Base.get_oc_code_rt(nothing, arginfo.oc, arginfo.tt, true)
243265
end
244266
src isa Core.CodeInfo || error("failed to infer source for $mi")
245267
str = _dump_function_native_assembly(mi, src, wrapper, syntax, debuginfo, binary, raw, params)
246268
end
247269
else
248-
if !isa(f, Core.OpaqueClosure)
270+
if arginfo.oc === nothing
249271
src = Base.Compiler.typeinf_code(Base.Compiler.NativeInterpreter(world), mi, true)
250272
else
251-
src, rt = Base.get_oc_code_rt(nothing, f, tt, true)
273+
src, rt = Base.get_oc_code_rt(nothing, arginfo.oc, arginfo.tt, true)
252274
end
253275
src isa Core.CodeInfo || error("failed to infer source for $mi")
254276
str = _dump_function_llvm(mi, src, wrapper, !raw, dump_module, optimize, debuginfo, params)
@@ -311,16 +333,18 @@ Keyword argument `debuginfo` may be one of source (default) or none, to specify
311333
312334
See also: [`@code_llvm`](@ref), [`code_warntype`](@ref), [`code_typed`](@ref), [`code_lowered`](@ref), [`code_native`](@ref).
313335
"""
314-
function code_llvm(io::IO, @nospecialize(f), @nospecialize(types=Base.default_tt(f));
336+
function code_llvm(io::IO, arginfo::ArgInfo;
315337
raw::Bool=false, dump_module::Bool=false, optimize::Bool=true, debuginfo::Symbol=:default,
316338
params::CodegenParams=CodegenParams(debug_info_kind=Cint(0), debug_info_level=Cint(2), safepoint_on_entry=raw, gcstack_arg=raw))
317-
d = _dump_function(f, types, false, false, raw, dump_module, :intel, optimize, debuginfo, false, params)
339+
d = _dump_function(arginfo, false, false, raw, dump_module, :intel, optimize, debuginfo, false, params)
318340
if highlighting[:llvm] && get(io, :color, false)::Bool
319341
print_llvm(io, d)
320342
else
321343
print(io, d)
322344
end
323345
end
346+
code_llvm(io::IO, @nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}); kwargs...) = code_llvm(io, ArgInfo(argtypes); kwargs...)
347+
code_llvm(io::IO, @nospecialize(f), @nospecialize(types=Base.default_tt(f)); kwargs...) = code_llvm(io, ArgInfo(f, types); kwargs...)
324348
code_llvm(args...; kwargs...) = (@nospecialize; code_llvm(stdout, args...; kwargs...))
325349

326350
"""
@@ -337,17 +361,19 @@ generic function and type signature to `io`.
337361
338362
See also: [`@code_native`](@ref), [`code_warntype`](@ref), [`code_typed`](@ref), [`code_lowered`](@ref), [`code_llvm`](@ref).
339363
"""
340-
function code_native(io::IO, @nospecialize(f), @nospecialize(types=Base.default_tt(f));
364+
function code_native(io::IO, arginfo::ArgInfo;
341365
dump_module::Bool=true, syntax::Symbol=:intel, raw::Bool=false,
342366
debuginfo::Symbol=:default, binary::Bool=false,
343367
params::CodegenParams=CodegenParams(debug_info_kind=Cint(0), debug_info_level=Cint(2), safepoint_on_entry=raw, gcstack_arg=raw))
344-
d = _dump_function(f, types, true, false, raw, dump_module, syntax, true, debuginfo, binary, params)
368+
d = _dump_function(arginfo, true, false, raw, dump_module, syntax, true, debuginfo, binary, params)
345369
if highlighting[:native] && get(io, :color, false)::Bool
346370
print_native(io, d)
347371
else
348372
print(io, d)
349373
end
350374
end
375+
code_native(io::IO, @nospecialize(argtypes::Union{Tuple,Type{<:Tuple}}); kwargs...) = code_native(io, ArgInfo(argtypes); kwargs...)
376+
code_native(io::IO, @nospecialize(f), @nospecialize(types=Base.default_tt(f)); kwargs...) = code_native(io, ArgInfo(f, types); kwargs...)
351377
code_native(args...; kwargs...) = (@nospecialize; code_native(stdout, args...; kwargs...))
352378

353379
## colorized IR and assembly printing

stdlib/InteractiveUtils/test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,9 @@ end # module ReflectionTest
592592
# Issue #18883, code_llvm/code_native for generated functions
593593
@generated f18883() = nothing
594594
@test !isempty(sprint(code_llvm, f18883, Tuple{}))
595+
@test !isempty(sprint(code_llvm, (typeof(f18883),)))
595596
@test !isempty(sprint(code_native, f18883, Tuple{}))
597+
@test !isempty(sprint(code_native, (typeof(f18883),)))
596598

597599
ix86 = r"i[356]86"
598600

@@ -865,6 +867,27 @@ let # `default_tt` should work with any function with one method
865867
end); true)
866868
end
867869

870+
let # specifying calls as argtypes (incl. arg0) should be supported
871+
@test (code_warntype(devnull, (typeof(function ()
872+
sin(42)
873+
end),)); true)
874+
@test (code_warntype(devnull, (typeof(function (a::Int)
875+
sin(42)
876+
end), Int)); true)
877+
@test (code_llvm(devnull, (typeof(function ()
878+
sin(42)
879+
end),)); true)
880+
@test (code_llvm(devnull, (typeof(function (a::Int)
881+
sin(42)
882+
end), Int)); true)
883+
@test (code_native(devnull, (typeof(function ()
884+
sin(42)
885+
end),)); true)
886+
@test (code_native(devnull, (typeof(function (a::Int)
887+
sin(42)
888+
end), Int)); true)
889+
end
890+
868891
@testset "code_llvm on opaque_closure" begin
869892
let ci = code_typed(+, (Int, Int))[1][1]
870893
ir = Core.Compiler.inflate_ir(ci)

test/rebinding.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ module RangeMerge
285285

286286
function get_llvm(@nospecialize(f), @nospecialize(t), raw=true, dump_module=false, optimize=true)
287287
params = Base.CodegenParams(safepoint_on_entry=false, gcstack_arg = false, debug_info_level=Cint(2))
288-
d = InteractiveUtils._dump_function(f, t, false, false, raw, dump_module, :att, optimize, :none, false, params)
288+
d = InteractiveUtils._dump_function(InteractiveUtils.ArgInfo(f, t), false, false, raw, dump_module, :att, optimize, :none, false, params)
289289
sprint(print, d)
290290
end
291291

test/reflection.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ function test_ir_reflection(freflect, f, types)
1616
nothing
1717
end
1818

19+
function test_ir_reflection(freflect, argtypes)
20+
@test !isempty(freflect(argtypes))
21+
nothing
22+
end
23+
1924
function test_bin_reflection(freflect, f, types)
2025
iob = IOBuffer()
2126
freflect(iob, f, types)
@@ -27,6 +32,9 @@ end
2732
function test_code_reflection(freflect, f, types, tester)
2833
tester(freflect, f, types)
2934
tester(freflect, f, (types.parameters...,))
35+
tt = Base.signature_type(f, types)
36+
tester(freflect, tt)
37+
tester(freflect, (tt.parameters...,))
3038
nothing
3139
end
3240

@@ -43,6 +51,7 @@ end
4351

4452
test_code_reflections(test_ir_reflection, code_lowered)
4553
test_code_reflections(test_ir_reflection, code_typed)
54+
test_code_reflections(test_ir_reflection, Base.code_ircode)
4655

4756
io = IOBuffer()
4857
Base.print_statement_costs(io, map, (typeof(sqrt), Tuple{Int}))
@@ -682,6 +691,10 @@ end
682691
@test Base.code_typed_by_type(Tuple{Type{<:Val}})[2][2] == Val
683692
@test Base.code_typed_by_type(Tuple{typeof(sin), Float64})[1][2] === Float64
684693

694+
# functor-like code_typed(...)
695+
@test Base.code_typed((Type{<:Val},))[2][2] == Val
696+
@test Base.code_typed((typeof(sin), Float64))[1][2] === Float64
697+
685698
# New reflection methods in 0.6
686699
struct ReflectionExample{T<:AbstractFloat, N}
687700
x::Tuple{T, N}
@@ -1038,11 +1051,12 @@ _test_at_locals2(1,1,0.5f0)
10381051

10391052
@testset "issue #31687" begin
10401053
import InteractiveUtils._dump_function
1054+
import InteractiveUtils.ArgInfo
10411055

10421056
@noinline f31687_child(i) = f31687_nonexistent(i)
10431057
f31687_parent() = f31687_child(0)
10441058
params = Base.CodegenParams()
1045-
_dump_function(f31687_parent, Tuple{},
1059+
_dump_function(ArgInfo(f31687_parent, Tuple{}),
10461060
#=native=#false, #=wrapper=#false, #=raw=#true,
10471061
#=dump_module=#true, #=syntax=#:att, #=optimize=#false, :none,
10481062
#=binary=#false)
@@ -1131,9 +1145,12 @@ end
11311145
@test 1+1 == 2
11321146
mi1 = Base.method_instance(+, (Int, Int))
11331147
@test mi1.def.name == :+
1134-
# Note `jl_method_lookup` doesn't returns CNull if not found
1135-
mi2 = @ccall jl_method_lookup(Any[+, 1, 1]::Ptr{Any}, 3::Csize_t, Base.get_world_counter()::Csize_t)::Ref{Core.MethodInstance}
1136-
@test mi1 == mi2
1148+
mi2 = Base.method_instance((typeof(+), Int, Int))
1149+
@test mi2.def.name == :+
1150+
# Note `jl_method_lookup` doesn't return CNull if not found
1151+
mi3 = @ccall jl_method_lookup(Any[+, 1, 1]::Ptr{Any}, 3::Csize_t, Base.get_world_counter()::Csize_t)::Ref{Core.MethodInstance}
1152+
@test mi1 == mi3
1153+
@test mi2 == mi3
11371154
end
11381155

11391156
Base.@assume_effects :terminates_locally function issue41694(x::Int)

0 commit comments

Comments
 (0)