Skip to content

Commit 3a73ef5

Browse files
wsmosesvchuravy
andauthored
Use function type in fspec instead of value (#320)
Co-authored-by: Valentin Churavy <[email protected]>
1 parent e810ec6 commit 3a73ef5

File tree

11 files changed

+94
-39
lines changed

11 files changed

+94
-39
lines changed

examples/kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntim
1616
kernel() = nothing
1717

1818
function main()
19-
source = FunctionSpec(kernel)
19+
source = FunctionSpec(typeof(kernel))
2020
target = NativeCompilerTarget()
2121
params = TestCompilerParams()
2222
job = CompilerJob(target, source, params)

src/driver.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,14 @@ end
147147
@timeit_debug to "Julia front-end" begin
148148

149149
# get the method instance
150-
meth = which(job.source.f, job.source.tt)
151-
sig = Base.signature_type(job.source.f, job.source.tt)::Type
150+
sig = typed_signature(job)
151+
meth = which(sig)
152+
152153
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
153154
(Any, Any), sig, meth.sig)::Core.SimpleVector
155+
154156
meth = Base.func_for_method_checked(meth, ti, env)
157+
155158
method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
156159
(Any, Any, Any, UInt), meth, ti, env, job.source.world)
157160

src/interface.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ export FunctionSpec
6262
# what we'll be compiling
6363

6464
struct FunctionSpec{F,TT}
65-
f::F
65+
f::Type{F}
6666
tt::Type{TT}
6767
kernel::Bool
6868
name::Union{Nothing,String}
@@ -85,8 +85,11 @@ end
8585
# world age intersection when querying the compilation cache. once we do, callers
8686
# should probably provide the world age of the calling code (!= the current world age)
8787
# so that querying the cache from, e.g. `cufuncton` is a fully static operation.
88+
FunctionSpec(f::Type, tt=Tuple{}, kernel=true, name=nothing, world_age=-1%UInt) =
89+
FunctionSpec{f,tt}(f, tt, kernel, name, world_age)
90+
8891
FunctionSpec(f, tt=Tuple{}, kernel=true, name=nothing, world_age=-1%UInt) =
89-
FunctionSpec{typeof(f),tt}(f, tt, kernel, name, world_age)
92+
FunctionSpec(Core.Typeof(f), tt, kernel, name, world_age)
9093

9194
function Base.getproperty(@nospecialize(spec::FunctionSpec), sym::Symbol)
9295
if sym == :world

src/irgen.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ end
302302
end
303303

304304
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
305-
source_sig = Base.signature_type(job.source.f, job.source.tt)::Type
305+
source_sig = typed_signature(job)
306+
306307
source_types = [source_sig.parameters...]
307308

308309
codegen_types = parameters(codegen_ft)

src/reflection.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ end
4747
code_lowered(@nospecialize(job::CompilerJob); kwargs...) =
4848
InteractiveUtils.code_lowered(job.source.f, job.source.tt; kwargs...)
4949

50+
@inline function typed_signature(@nospecialize(job::CompilerJob))
51+
u = Base.unwrap_unionall(job.source.tt)
52+
return Base.rewrap_unionall(Tuple{job.source.f, u.parameters...}, job.source.tt)
53+
end
54+
5055
function code_typed(@nospecialize(job::CompilerJob); interactive::Bool=false, kwargs...)
5156
# TODO: use the compiler driver to get the Julia method instance (we might rewrite it)
5257
if interactive
@@ -58,9 +63,9 @@ function code_typed(@nospecialize(job::CompilerJob); interactive::Bool=false, kw
5863
descend_code_typed(job.source.f, job.source.tt; interp, kwargs...)
5964
elseif VERSION >= v"1.7-"
6065
interp = get_interpreter(job)
61-
InteractiveUtils.code_typed(job.source.f, job.source.tt; interp, kwargs...)
66+
Base.code_typed_by_type(typed_signature(job); interp, kwargs...)
6267
else
63-
InteractiveUtils.code_typed(job.source.f, job.source.tt; kwargs...)
68+
Base.code_typed_by_type(typed_signature(job); kwargs...)
6469
end
6570
end
6671

src/rtlib.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function build_runtime(@nospecialize(job::CompilerJob); ctx)
112112

113113
# the compiler job passed into here is identifies the job that requires the runtime.
114114
# derive a job that represents the runtime itself (notably with kernel=false).
115-
source = FunctionSpec(identity, Tuple{Nothing}, false, nothing, job.source.world_age)
115+
source = FunctionSpec(typeof(identity), Tuple{Nothing}, false, nothing, job.source.world_age)
116116
job = CompilerJob(job.target, source, job.params)
117117

118118
for method in values(Runtime.methods)
@@ -122,7 +122,7 @@ function build_runtime(@nospecialize(job::CompilerJob); ctx)
122122
else
123123
method.def
124124
end
125-
emit_function!(mod, job, def, method; ctx)
125+
emit_function!(mod, job, typeof(def), method; ctx)
126126
end
127127

128128
# we cannot optimize the runtime library, because the code would then be optimized again

src/validation.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,36 @@
22

33
export InvalidIRError
44

5+
function get_method_matches(@nospecialize(job::CompilerJob))
6+
tt = typed_signature(job)
7+
8+
ms = Core.MethodMatch[]
9+
for m in Base._methods_by_ftype(tt, -1, job.source.world)::Vector
10+
m = m::Core.MethodMatch
11+
push!(ms, m)
12+
end
13+
14+
return ms
15+
end
16+
17+
518
function check_method(@nospecialize(job::CompilerJob))
619
isa(job.source.f, Core.Builtin) && throw(KernelError(job, "function is not a generic function"))
720

821
# get the method
9-
ms = Base.methods(job.source.f, job.source.tt)
22+
ms = get_method_matches(job)
1023
isempty(ms) && throw(KernelError(job, "no method found"))
1124
length(ms)!=1 && throw(KernelError(job, "no unique matching method"))
12-
m = first(ms)
1325

1426
# kernels can't return values
1527
if job.source.kernel
1628
cache = ci_cache(job)
1729
mt = method_table(job)
1830
interp = GPUInterpreter(cache, mt, job.source.world)
19-
@static if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION v"1.9.0-DEV.190"
20-
# https://github.com/JuliaLang/julia/pull/44515
21-
rt = Base.return_types(job.source.f, job.source.tt; interp)[1]
22-
else
23-
rt = Base.return_types(job.source.f, job.source.tt, interp)[1]
24-
end
31+
m = only(ms)
32+
ty = Core.Compiler.typeinf_type(interp, m.method, m.spec_types, m.sparams)
33+
rt = something(ty, Any)
34+
2535
if rt != Nothing
2636
throw(KernelError(job, "kernel returns a value of type `$rt`",
2737
"""Make sure your kernel function ends in `return`, `return nothing` or `nothing`.
@@ -61,7 +71,9 @@ end
6171
function check_invocation(@nospecialize(job::CompilerJob))
6272
# make sure any non-isbits arguments are unused
6373
real_arg_i = 0
64-
sig = Base.signature_type(job.source.f, job.source.tt)::Type
74+
75+
sig = typed_signature(job)
76+
6577
for (arg_i,dt) in enumerate(sig.parameters)
6678
isghosttype(dt) && continue
6779
Core.Compiler.isconstType(dt) && continue

test/definitions/gcn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ end
88
# create a GCN-based test compiler, and generate reflection methods for it
99

1010
function gcn_job(@nospecialize(func), @nospecialize(types); kernel::Bool=false, kwargs...)
11-
source = FunctionSpec(func, Base.to_tuple_type(types), kernel)
11+
source = FunctionSpec(typeof(func), Base.to_tuple_type(types), kernel)
1212
target = GCNCompilerTarget(dev_isa="gfx900")
1313
params = TestCompilerParams()
1414
CompilerJob(target, source, params), kwargs

test/definitions/native.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ end
1818

1919
GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) = method_table
2020

21-
function native_job(@nospecialize(func), @nospecialize(types); kernel::Bool=false, entry_abi=:specfunc, kwargs...)
22-
source = FunctionSpec(func, Base.to_tuple_type(types), kernel)
21+
function native_job(@nospecialize(f_type), @nospecialize(types); kernel::Bool=false, entry_abi=:specfunc, kwargs...)
22+
source = FunctionSpec(f_type, Base.to_tuple_type(types), kernel)
2323
target = NativeCompilerTarget(always_inline=true)
2424
params = TestCompilerParams()
2525
CompilerJob(target, source, params, entry_abi), kwargs
@@ -249,8 +249,8 @@ module LazyCodegen
249249
end
250250

251251
import GPUCompiler: deferred_codegen_jobs
252-
@generated function deferred_codegen(::Val{f}, ::Val{tt}) where {f,tt}
253-
job, _ = native_job(f, tt)
252+
@generated function deferred_codegen(f::F, ::Val{tt}) where {F,tt}
253+
job, _ = native_job(F, tt)
254254

255255
addr = get_trampoline(job)
256256
trampoline = pointer(addr)
@@ -265,7 +265,7 @@ module LazyCodegen
265265
end
266266
end
267267

268-
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, args::Vararg{Any, N}) where {T, RT, N}
268+
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
269269
argtt = tt.parameters[1]
270270
rettype = rt.parameters[1]
271271
argtypes = DataType[argtt.parameters...]
@@ -276,8 +276,27 @@ module LazyCodegen
276276
before = :()
277277
after = :(ret)
278278

279+
279280
# Note this follows: emit_call_specfun_other
280281
JuliaContext() do ctx
282+
283+
if !isghosttype(F) && !Core.Compiler.isconstType(F)
284+
isboxed = GPUCompiler.deserves_argbox(F)
285+
argexpr = :(func)
286+
if isboxed
287+
push!(ccall_types, Any)
288+
else
289+
et = convert(LLVMType, func; ctx)
290+
if isa(et, LLVM.SequentialType) # et->isAggregateType
291+
push!(ccall_types, Ptr{F})
292+
argexpr = Expr(:call, GlobalRef(Base, :Ref), argexpr)
293+
else
294+
push!(ccall_types, F)
295+
end
296+
end
297+
push!(argexprs, argexpr)
298+
end
299+
281300
T_jlvalue = LLVM.StructType(LLVMType[],;ctx)
282301
T_prjlvalue = LLVM.PointerType(T_jlvalue, #= AddressSpace::Tracked =# 10)
283302

@@ -330,7 +349,7 @@ module LazyCodegen
330349
@inline function call_delayed(f::F, args...) where F
331350
tt = Tuple{map(Core.Typeof, args)...}
332351
rt = Core.Compiler.return_type(f, tt)
333-
ptr = deferred_codegen(Val(f), Val(tt))
334-
abi_call(ptr, rt, tt, args...)
352+
ptr = deferred_codegen(f, Val(tt))
353+
abi_call(ptr, rt, tt, f, args...)
335354
end
336355
end

test/native.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ include("definitions/native.jl")
1111

1212
(CI, rt) = native_code_typed(MyCallable(), (Int, Int), kernel=false)[1]
1313
@test CI.slottypes[1] == Core.Compiler.Const(MyCallable())
14+
15+
(CI, rt) = native_code_typed(typeof(MyCallable()), (Int, Int), kernel=false)[1]
16+
@test CI.slottypes[1] == Core.Compiler.Const(MyCallable())
1417
end
1518

1619
@testset "Compilation database" begin
@@ -339,11 +342,7 @@ end
339342

340343
# Test ABI removal
341344
ir = sprint(io->native_code_llvm(io, call_real, Tuple{ComplexF64}))
342-
if VERSION < v"1.8-" || v"1.8-beta2" <= VERSION < v"1.9-" || VERSION v"1.9.0-DEV.190"
343-
@test !occursin("alloca", ir)
344-
else
345-
@test_broken !occursin("alloca", ir)
346-
end
345+
@test !occursin("alloca", ir)
347346

348347
ghostly_identity(x, y) = y
349348
@test call_delayed(ghostly_identity, nothing, 1) == 1
@@ -358,6 +357,20 @@ end
358357
throws(arr, i) = arr[i]
359358
@test call_delayed(throws, [1], 1) == 1
360359
@test_throws BoundsError call_delayed(throws, [1], 0)
360+
361+
struct Closure
362+
x::Int64
363+
end
364+
(c::Closure)(b) = c.x+b
365+
366+
@test call_delayed(Closure(3), 5) == 8
367+
368+
struct Closure2
369+
x::Integer
370+
end
371+
(c::Closure2)(b) = c.x+b
372+
373+
@test call_delayed(Closure2(3), 5) == 8
361374
end
362375

363376
############################################################################################

0 commit comments

Comments
 (0)