Skip to content

Commit f2d67b2

Browse files
authored
Post-merge cleanup of FunctionSpec type change (#323)
1 parent 1628800 commit f2d67b2

File tree

10 files changed

+207
-31
lines changed

10 files changed

+207
-31
lines changed

src/irgen.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ end
389389
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
390390
ctx = context(mod)
391391
ft = eltype(llvmtype(f))
392-
@compiler_assert return_type(ft) == LLVM.VoidType(ctx) job
392+
@compiler_assert LLVM.return_type(ft) == LLVM.VoidType(ctx) job
393393

394394
# find the byval parameters
395395
byval = BitVector(undef, length(parameters(ft)))
@@ -407,7 +407,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
407407
has_kernel_state = kernel_state_type(job) !== Nothing
408408
orig_ft = if has_kernel_state
409409
# the kernel state has been added here already, so strip the first parameter
410-
LLVM.FunctionType(return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
410+
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
411411
else
412412
ft
413413
end
@@ -462,7 +462,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
462462
push!(new_types, param)
463463
end
464464
end
465-
new_ft = LLVM.FunctionType(return_type(ft), new_types)
465+
new_ft = LLVM.FunctionType(LLVM.return_type(ft), new_types)
466466
new_f = LLVM.Function(mod, "", new_ft)
467467
linkage!(new_f, linkage(f))
468468
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
@@ -595,7 +595,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
595595

596596
# create a new function
597597
new_param_types = [T_state, parameters(ft)...]
598-
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
598+
new_ft = LLVM.FunctionType(LLVM.return_type(ft), new_param_types)
599599
new_f = LLVM.Function(mod, fn, new_ft)
600600
LLVM.name!(parameters(new_f)[1], "state")
601601
linkage!(new_f, linkage(f))
@@ -627,7 +627,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
627627
# is all this even sound?
628628
typ = llvmtype(val)::LLVM.PointerType
629629
ft = eltype(typ)::LLVM.FunctionType
630-
new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)...])
630+
new_ft = LLVM.FunctionType(LLVM.return_type(ft), [T_state, parameters(ft)...])
631631
return const_bitcast(workmap[target], LLVM.PointerType(new_ft, addrspace(typ)))
632632
end
633633
elseif opcode(val) == LLVM.API.LLVMPtrToInt

src/optim.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ function cpu_features!(mod::LLVM.Module)
302302
# warn?
303303
false
304304
end
305-
has_fma = ConstantInt(return_type(ft), has_fma)
305+
has_fma = ConstantInt(LLVM.return_type(ft), has_fma)
306306

307307
# substitute all uses of the intrinsic with a constant
308308
materialized = LLVM.Value[]
@@ -342,7 +342,7 @@ function lower_gc_frame!(fun::LLVM.Function)
342342
if haskey(functions(mod), "julia.gc_alloc_obj")
343343
alloc_obj = functions(mod)["julia.gc_alloc_obj"]
344344
alloc_obj_ft = eltype(llvmtype(alloc_obj))
345-
T_prjlvalue = return_type(alloc_obj_ft)
345+
T_prjlvalue = LLVM.return_type(alloc_obj_ft)
346346
T_pjlvalue = convert(LLVMType, Any; ctx, allow_boxed=true)
347347

348348
for use in uses(alloc_obj)

src/ptx.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ function hide_unreachable!(fun::LLVM.Function)
337337
# couldn't find any other successor. this happens with functions
338338
# that only contain a single block, or when the block is dead.
339339
ft = eltype(llvmtype(fun))
340-
if return_type(ft) == LLVM.VoidType(ctx)
340+
if LLVM.return_type(ft) == LLVM.VoidType(ctx)
341341
# even though returning can lead to invalid control flow,
342342
# it mostly happens with functions that just throw,
343343
# and leaving the unreachable there would make the optimizer
@@ -411,7 +411,7 @@ function nvvm_reflect!(fun::LLVM.Function)
411411
haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false
412412
reflect_function = LLVM.functions(mod)[NVVM_REFLECT_FUNCTION]
413413
isdeclaration(reflect_function) || error("_reflect function should not have a body")
414-
reflect_typ = return_type(eltype(llvmtype(reflect_function)))
414+
reflect_typ = LLVM.return_type(eltype(llvmtype(reflect_function)))
415415
isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer")
416416

417417
to_remove = []

src/reflection.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,51 +39,58 @@ function highlight(io::IO, code, lexer)
3939
return
4040
end
4141

42+
#
43+
# Compat shims
44+
#
45+
46+
include("reflection_compat.jl")
4247

4348
#
4449
# code_* replacements
4550
#
4651

47-
code_lowered(@nospecialize(job::CompilerJob); kwargs...) =
48-
InteractiveUtils.code_lowered(job.source.f, job.source.tt; kwargs...)
49-
5052
@inline function typed_signature(@nospecialize(job::CompilerJob))
5153
u = Base.unwrap_unionall(job.source.tt)
5254
return Base.rewrap_unionall(Tuple{job.source.f, u.parameters...}, job.source.tt)
5355
end
5456

57+
code_lowered(@nospecialize(job::CompilerJob); kwargs...) =
58+
code_lowered_by_type(typed_signature(job); kwargs...)
59+
5560
function code_typed(@nospecialize(job::CompilerJob); interactive::Bool=false, kwargs...)
5661
# TODO: use the compiler driver to get the Julia method instance (we might rewrite it)
62+
tt = typed_signature(job)
5763
if interactive
5864
# call Cthulhu without introducing a dependency on Cthulhu
5965
mod = get(Base.loaded_modules, Cthulhu, nothing)
6066
mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
6167
interp = get_interpreter(job)
6268
descend_code_typed = getfield(mod, :descend_code_typed)
63-
descend_code_typed(job.source.f, job.source.tt; interp, kwargs...)
69+
descend_code_typed(tt; interp, kwargs...)
6470
elseif VERSION >= v"1.7-"
6571
interp = get_interpreter(job)
66-
Base.code_typed_by_type(typed_signature(job); interp, kwargs...)
72+
Base.code_typed_by_type(tt; interp, kwargs...)
6773
else
68-
Base.code_typed_by_type(typed_signature(job); kwargs...)
74+
Base.code_typed_by_type(tt; kwargs...)
6975
end
7076
end
7177

7278
function code_warntype(io::IO, @nospecialize(job::CompilerJob); interactive::Bool=false, kwargs...)
7379
# TODO: use the compiler driver to get the Julia method instance (we might rewrite it)
80+
tt = typed_signature(job)
7481
if interactive
7582
@assert io == stdout
7683
# call Cthulhu without introducing a dependency on Cthulhu
7784
mod = get(Base.loaded_modules, Cthulhu, nothing)
7885
mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.")
7986
interp = get_interpreter(job)
8087
descend_code_warntype = getfield(mod, :descend_code_warntype)
81-
descend_code_warntype(job.source.f, job.source.tt; interp, kwargs...)
88+
descend_code_warntype(tt; interp, kwargs...)
8289
elseif VERSION >= v"1.7-"
8390
interp = get_interpreter(job)
84-
InteractiveUtils.code_warntype(io, job.source.f, job.source.tt; interp, kwargs...)
91+
code_warntype_by_type(io, tt; interp, kwargs...)
8592
else
86-
InteractiveUtils.code_warntype(io, job.source.f, job.source.tt; kwargs...)
93+
code_warntype_by_type(io, tt; kwargs...)
8794
end
8895
end
8996
code_warntype(@nospecialize(job::CompilerJob); kwargs...) = code_warntype(stdout, job; kwargs...)

src/reflection_compat.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# The content of this file should be upstreamed to Julia proper
2+
3+
function method_instances(@nospecialize(tt::Type), world::UInt=Base.get_world_counter())
4+
return map(Core.Compiler.specialize_method, method_matches(tt; world))
5+
end
6+
7+
if VERSION >= v"1.7-"
8+
const hasgenerator = Base.hasgenerator
9+
else
10+
const hasgenerator = Base.isgenerated
11+
end
12+
13+
function code_lowered_by_type(@nospecialize(tt); generated::Bool=true, debuginfo::Symbol=:default)
14+
15+
debuginfo = Base.IRShow.debuginfo(debuginfo)
16+
if debuginfo !== :source && debuginfo !== :none
17+
throw(ArgumentError("'debuginfo' must be either :source or :none"))
18+
end
19+
return map(method_instances(tt)) do m
20+
if generated && hasgenerator(m)
21+
if Base.may_invoke_generator(m)
22+
return ccall(:jl_code_for_staged, Any, (Any,), m)::CodeInfo
23+
else
24+
error("Could not expand generator for `@generated` method ", m, ". ",
25+
"This can happen if the provided argument types (", t, ") are ",
26+
"not leaf types, but the `generated` argument is `true`.")
27+
end
28+
end
29+
code = Base.uncompressed_ir(m.def::Method)
30+
debuginfo === :none && Base.remove_linenums!(code)
31+
return code
32+
end
33+
end
34+
35+
function code_warntype_by_type(io::IO, @nospecialize(tt);
36+
debuginfo::Symbol=:default, optimize::Bool=false, kwargs...)
37+
debuginfo = Base.IRShow.debuginfo(debuginfo)
38+
lineprinter = Base.IRShow.__debuginfo[debuginfo]
39+
for (src, rettype) in Base.code_typed_by_type(tt; optimize, kwargs...)
40+
if !(src isa Core.CodeInfo)
41+
println(io, src)
42+
println(io, " failed to infer")
43+
continue
44+
end
45+
lambda_io::IOContext = io
46+
p = src.parent
47+
nargs::Int = 0
48+
if p isa Core.MethodInstance
49+
println(io, p)
50+
print(io, " from ")
51+
println(io, p.def)
52+
p.def isa Method && (nargs = p.def.nargs)
53+
if !isempty(p.sparam_vals)
54+
println(io, "Static Parameters")
55+
sig = p.def.sig
56+
warn_color = Base.warn_color() # more mild user notification
57+
for i = 1:length(p.sparam_vals)
58+
sig = sig::UnionAll
59+
name = sig.var.name
60+
val = p.sparam_vals[i]
61+
print_highlighted(io::IO, v::String, color::Symbol) =
62+
if highlighting[:warntype]
63+
Base.printstyled(io, v; color)
64+
else
65+
Base.print(io, v)
66+
end
67+
if val isa TypeVar
68+
if val.lb === Union{}
69+
print(io, " ", name, " <: ")
70+
print_highlighted(io, "$(val.ub)", warn_color)
71+
elseif val.ub === Any
72+
print(io, " ", sig.var.name, " >: ")
73+
print_highlighted(io, "$(val.lb)", warn_color)
74+
else
75+
print(io, " ")
76+
print_highlighted(io, "$(val.lb)", warn_color)
77+
print(io, " <: ", sig.var.name, " <: ")
78+
print_highlighted(io, "$(val.ub)", warn_color)
79+
end
80+
elseif val isa typeof(Vararg)
81+
print(io, " ", name, "::")
82+
print_highlighted(io, "Int", warn_color)
83+
else
84+
print(io, " ", sig.var.name, " = ")
85+
print_highlighted(io, "$(val)", :cyan) # show the "good" type
86+
end
87+
println(io)
88+
sig = sig.body
89+
end
90+
end
91+
end
92+
if src.slotnames !== nothing
93+
slotnames = Base.sourceinfo_slotnames(src)
94+
lambda_io = IOContext(lambda_io, :SOURCE_SLOTNAMES => slotnames)
95+
slottypes = src.slottypes
96+
nargs > 0 && println(io, "Arguments")
97+
for i = 1:length(slotnames)
98+
if i == nargs + 1
99+
println(io, "Locals")
100+
end
101+
print(io, " ", slotnames[i])
102+
if isa(slottypes, Vector{Any})
103+
InteractiveUtils.warntype_type_printer(io, slottypes[i], true)
104+
end
105+
println(io)
106+
end
107+
end
108+
print(io, "Body")
109+
InteractiveUtils.warntype_type_printer(io, rettype, true)
110+
println(io)
111+
@static if VERSION < v"1.7.0"
112+
Base.IRShow.show_ir(lambda_io, src, lineprinter(src), InteractiveUtils.warntype_type_printer)
113+
else
114+
irshow_config = Base.IRShow.IRShowConfig(lineprinter(src), InteractiveUtils.warntype_type_printer)
115+
Base.IRShow.show_ir(lambda_io, src, irshow_config)
116+
end
117+
println(io)
118+
end
119+
nothing
120+
end

src/rtlib.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ function emit_function!(mod, @nospecialize(job::CompilerJob), f, method; ctx::Co
8181
optimize=false, libraries=false, ctx)
8282
ft = eltype(llvmtype(meta.entry))
8383
expected_ft = convert(LLVM.FunctionType, method; ctx=context(new_mod))
84-
if return_type(ft) != return_type(expected_ft)
85-
error("Invalid return type for runtime function '$(method.name)': expected $(return_type(expected_ft)), got $(return_type(ft))")
84+
if LLVM.return_type(ft) != LLVM.return_type(expected_ft)
85+
error("Invalid return type for runtime function '$(method.name)': expected $(LLVM.return_type(expected_ft)), got $(LLVM.return_type(ft))")
8686
end
8787

8888
# recent Julia versions include prototypes for all runtime functions, even if unused

src/spirv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ end
219219
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
220220
ctx = context(mod)
221221
ft = eltype(llvmtype(f)::LLVM.PointerType)::LLVM.FunctionType
222-
@compiler_assert return_type(ft) == LLVM.VoidType(ctx) job
222+
@compiler_assert LLVM.return_type(ft) == LLVM.VoidType(ctx) job
223223

224224
# find the byval parameters
225225
byval = BitVector(undef, length(parameters(ft)))
@@ -235,7 +235,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
235235
has_kernel_state = kernel_state_type(job) !== Nothing
236236
orig_ft = if has_kernel_state
237237
# the kernel state has been added here already, so strip the first parameter
238-
LLVM.FunctionType(return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
238+
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
239239
else
240240
ft
241241
end

src/validation.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,38 @@
22

33
export InvalidIRError
44

5-
function get_method_matches(@nospecialize(job::CompilerJob))
6-
tt = typed_signature(job)
7-
5+
function method_matches(@nospecialize(tt::Type{<:Tuple}); world=Base.get_world_counter())
86
ms = Core.MethodMatch[]
9-
for m in Base._methods_by_ftype(tt, -1, job.source.world)::Vector
7+
for m in Base._methods_by_ftype(tt, -1, world)::Vector
108
m = m::Core.MethodMatch
119
push!(ms, m)
1210
end
1311

1412
return ms
1513
end
1614

15+
function return_type(m::Core.MethodMatch;
16+
interp = Core.Compiler.NativeInterpreter(world))
17+
ty = Core.Compiler.typeinf_type(interp, m.method, m.spec_types, m.sparams)
18+
return something(ty, Any)
19+
end
20+
1721

1822
function check_method(@nospecialize(job::CompilerJob))
1923
isa(job.source.f, Core.Builtin) && throw(KernelError(job, "function is not a generic function"))
2024

2125
# get the method
22-
ms = get_method_matches(job)
26+
world = job.source.world
27+
ms = method_matches(typed_signature(job); world)
2328
isempty(ms) && throw(KernelError(job, "no method found"))
2429
length(ms)!=1 && throw(KernelError(job, "no unique matching method"))
2530

2631
# kernels can't return values
2732
if job.source.kernel
2833
cache = ci_cache(job)
2934
mt = method_table(job)
30-
interp = GPUInterpreter(cache, mt, job.source.world)
31-
m = only(ms)
32-
ty = Core.Compiler.typeinf_type(interp, m.method, m.spec_types, m.sparams)
33-
rt = something(ty, Any)
35+
interp = GPUInterpreter(cache, mt, world)
36+
rt = return_type(only(ms); interp)
3437

3538
if rt != Nothing
3639
throw(KernelError(job, "kernel returns a value of type `$rt`",

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
2+
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
23
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
34
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
5+
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
46
SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361"
57
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
68
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)