Skip to content

Commit 9af6e00

Browse files
committed
Make GPUInterpreter extensible
Currently Enzyme uses its own AbstractInterpreter, in particular to block inlining of functions with custom rules and to handle nested autodiff operations. - [ ] Create a version of Enzyme with this - [ ] Support a version of `gpuc.deferred(meta)`
1 parent 828ee63 commit 9af6e00

File tree

5 files changed

+210
-24
lines changed

5 files changed

+210
-24
lines changed

src/driver.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
## deferred compilation
4343

4444
"""
45-
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
45+
var"gpuc.deferred"(meta, f, args...)::Ptr{Cvoid}
4646
4747
As if we were to call `f(args...)` but instead we are
4848
putting down a marker and returning a function pointer to later
@@ -199,18 +199,19 @@ const __llvm_initialized = Ref(false)
199199
return val
200200
end
201201

202-
worklist = Dict{Any, Vector{LLVM.CallInst}}()
202+
worklist = Dict{MethodInstance, Vector{LLVM.CallInst}}()
203203
for use in uses(dyn_marker)
204204
# decode the call
205205
call = user(use)::LLVM.CallInst
206-
dyn_mi_inst = find_base_object(operands(call)[1])
206+
dyn_mi_inst = find_base_object(operands(call)[2])
207207
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
208208
dyn_mi = Base.unsafe_pointer_to_objref(
209-
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
209+
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))::MethodInstance
210210
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
211211
end
212212

213213
for dyn_mi in keys(worklist)
214+
# TODO: Should compiled become Edge[]
214215
dyn_fn_name = compiled[dyn_mi].specfunc
215216
dyn_fn = functions(ir)[dyn_fn_name]
216217

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
181181
# provide a specific interpreter to use.
182182
if VERSION >= v"1.11.0-DEV.1552"
183183
get_interpreter(@nospecialize(job::CompilerJob)) =
184-
GPUInterpreter(job.world; method_table=method_table(job),
184+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
185185
token=ci_cache_token(job), inf_params=inference_params(job),
186186
opt_params=optimization_params(job))
187187
else
188188
get_interpreter(@nospecialize(job::CompilerJob)) =
189-
GPUInterpreter(job.world; method_table=method_table(job),
189+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
190190
code_cache=ci_cache(job), inf_params=inference_params(job),
191191
opt_params=optimization_params(job))
192192
end

src/jlgen.jl

Lines changed: 189 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ else
320320
end
321321

322322
struct GPUInterpreter <: CC.AbstractInterpreter
323+
meta::Any
323324
world::UInt
324325
method_table::GPUMethodTableView
325326

@@ -336,6 +337,7 @@ end
336337

337338
@static if HAS_INTEGRATED_CACHE
338339
function GPUInterpreter(world::UInt=Base.get_world_counter();
340+
meta = nothing,
339341
method_table::MTType,
340342
token::Any,
341343
inf_params::CC.InferenceParams,
@@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
345347
method_table = get_method_table_view(world, method_table)
346348
inf_cache = Vector{CC.InferenceResult}()
347349

348-
return GPUInterpreter(world, method_table,
350+
return GPUInterpreter(meta, world, method_table,
349351
token, inf_cache,
350352
inf_params, opt_params)
351353
end
352354

353355
function GPUInterpreter(interp::GPUInterpreter;
356+
meta=interp.meta,
354357
world::UInt=interp.world,
355358
method_table::GPUMethodTableView=interp.method_table,
356359
token::Any=interp.token,
357360
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
358361
inf_params::CC.InferenceParams=interp.inf_params,
359362
opt_params::CC.OptimizationParams=interp.opt_params)
360-
return GPUInterpreter(world, method_table,
363+
return GPUInterpreter(meta, world, method_table,
361364
token, inf_cache,
362365
inf_params, opt_params)
363366
end
364367

365368
else
366369

367370
function GPUInterpreter(world::UInt=Base.get_world_counter();
371+
meta=nothing,
368372
method_table::MTType,
369373
code_cache::CodeCache,
370374
inf_params::CC.InferenceParams,
@@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
374378
method_table = get_method_table_view(world, method_table)
375379
inf_cache = Vector{CC.InferenceResult}()
376380

377-
return GPUInterpreter(world, method_table,
381+
return GPUInterpreter(meta, world, method_table,
378382
code_cache, inf_cache,
379383
inf_params, opt_params)
380384
end
381385

382386
function GPUInterpreter(interp::GPUInterpreter;
387+
meta=interp.meta,
383388
world::UInt=interp.world,
384389
method_table::GPUMethodTableView=interp.method_table,
385390
code_cache::CodeCache=interp.code_cache,
386391
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
387392
inf_params::CC.InferenceParams=interp.inf_params,
388393
opt_params::CC.OptimizationParams=interp.opt_params)
389-
return GPUInterpreter(world, method_table,
394+
return GPUInterpreter(meta, world, method_table,
390395
code_cache, inf_cache,
391396
inf_params, opt_params)
392397
end
@@ -437,28 +442,76 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
437442
end
438443

439444

445+
within_gpucompiler() = false
446+
440447
## deferred compilation
441448

442449
struct DeferredCallInfo <: CC.CallInfo
450+
meta::Any
443451
rt::DataType
444452
info::CC.CallInfo
445453
end
446454

447455
# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
448-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
449-
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
450-
max_methods::Int = CC.get_max_methods(interp, f, sv))
456+
# default implementation, extensible through meta argument.
457+
# XXX: (or should we dispatch on `f`)?
458+
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
459+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
460+
max_methods::Int = CC.get_max_methods(interp, f, sv))
451461
(; fargs, argtypes) = arginfo
452462
if f === var"gpuc.deferred"
453-
argvec = argtypes[2:end]
463+
argvec = argtypes[3:end]
454464
call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
455-
callinfo = DeferredCallInfo(call.rt, call.info)
465+
metaT = argtypes[2]
466+
meta = CC.singleton_type(metaT)
467+
if meta === nothing
468+
if metaT isa Core.Const
469+
meta = metaT.val
470+
else
471+
# meta is not a singleton type; the result may depend on runtime configuration
472+
add_remark!(interp, sv, "Skipped gpuc.deferred since meta not constant")
473+
@static if VERSION < v"1.11.0-"
474+
return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
475+
else
476+
return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
477+
end
478+
end
479+
end
480+
481+
callinfo = DeferredCallInfo(meta, call.rt, call.info)
456482
@static if VERSION < v"1.11.0-"
457483
return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
458484
else
459485
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
460486
end
487+
elseif f === within_gpucompiler
488+
if length(argtypes) != 1
489+
@static if VERSION < v"1.11.0-"
490+
return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
491+
else
492+
return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
493+
end
494+
end
495+
@static if VERSION < v"1.11.0-"
496+
return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
497+
else
498+
return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),)
499+
end
461500
end
501+
return nothing
502+
end
503+
504+
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
505+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
506+
max_methods::Int = CC.get_max_methods(interp, f, sv))
507+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
508+
if candidate === nothing && interp.meta !== nothing
509+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
510+
end
511+
if candidate !== nothing
512+
return candidate
513+
end
514+
462515
return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
463516
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
464517
max_methods::Int)
@@ -485,32 +538,39 @@ function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int
485538
args = Any[
486539
"extern gpuc.lookup",
487540
Ptr{Cvoid},
488-
Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
541+
Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
489542
0,
490543
QuoteNode(:llvmcall),
544+
info.meta,
491545
case.invoke,
492-
stmt.args[2:end]...
546+
stmt.args[3:end]...
493547
]
494548
stmt.head = :foreigncall
495549
stmt.args = args
496550
return nothing
497551
end
498552

553+
struct Edge
554+
meta::Any
555+
mi::MethodInstance
556+
end
557+
499558
struct DeferredEdges
500-
edges::Vector{MethodInstance}
559+
edges::Vector{Edge}
501560
end
502561

503562
function find_deferred_edges(ir::CC.IRCode)
504-
edges = MethodInstance[]
563+
edges = Edge[]
505564
# XXX: can we add this instead in handle_call?
506565
for stmt in ir.stmts
507566
inst = stmt[:inst]
508567
inst isa Expr || continue
509568
expr = inst::Expr
510569
if expr.head === :foreigncall &&
511570
expr.args[1] == "extern gpuc.lookup"
512-
deferred_mi = expr.args[6]
513-
push!(edges, deferred_mi)
571+
deferred_meta = expr.args[6]
572+
deferred_mi = expr.args[7]
573+
push!(edges, Edge(deferred_meta, deferred_mi))
514574
end
515575
end
516576
unique!(edges)
@@ -542,6 +602,116 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
542602
end
543603
end
544604

605+
import .CC: CallInfo
606+
struct NoInlineCallInfo <: CallInfo
607+
info::CallInfo # wrapped call
608+
tt::Any # ::Type
609+
kind::Symbol
610+
NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) =
611+
new(info, tt, kind)
612+
end
613+
614+
CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info)
615+
CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
616+
CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
617+
struct AlwaysInlineCallInfo <: CallInfo
618+
info::CallInfo # wrapped call
619+
tt::Any # ::Type
620+
AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt)
621+
end
622+
623+
CC.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info)
624+
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
625+
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
626+
627+
628+
function inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo)
629+
return nothing
630+
end
631+
632+
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
633+
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo,
634+
si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int)
635+
ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any, arginfo::ArgInfo,
636+
si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int)
637+
638+
callinfo = nothing
639+
if interp.meta !== nothing
640+
callinfo = inlining_handler(interp.meta, interp, atype, ret.info)
641+
end
642+
if callinfo === nothing
643+
callinfo = inlining_handler(nothing, interp, atype, ret.info)
644+
end
645+
if callinfo === nothing
646+
callinfo = ret.info
647+
end
648+
649+
@static if VERSION >= v"1.11-"
650+
return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
651+
else
652+
return CC.CallMeta(ret.rt, ret.effects, callinfo)
653+
end
654+
end
655+
656+
@static if VERSION < v"1.12.0-DEV.45"
657+
let # overload `inlining_policy`
658+
@static if VERSION >= v"1.11.0-DEV.879"
659+
sigs_ex = :(
660+
interp::GPUInterpreter,
661+
@nospecialize(src),
662+
@nospecialize(info::CC.CallInfo),
663+
stmt_flag::UInt32,
664+
)
665+
args_ex = :(
666+
interp::CC.AbstractInterpreter,
667+
src::Any,
668+
info::CC.CallInfo,
669+
stmt_flag::UInt32,
670+
)
671+
else
672+
sigs_ex = :(
673+
interp::GPUInterpreter,
674+
@nospecialize(src),
675+
@nospecialize(info::CC.CallInfo),
676+
stmt_flag::UInt8,
677+
mi::MethodInstance,
678+
argtypes::Vector{Any},
679+
)
680+
args_ex = :(
681+
interp::CC.AbstractInterpreter,
682+
src::Any,
683+
info::CC.CallInfo,
684+
stmt_flag::UInt8,
685+
mi::MethodInstance,
686+
argtypes::Vector{Any},
687+
)
688+
end
689+
@eval function CC.inlining_policy($(sigs_ex.args...))
690+
if info isa NoInlineCallInfo
691+
@safe_debug "Blocking inlining" info.tt info.kind
692+
return nothing
693+
elseif info isa AlwaysInlineCallInfo
694+
@safe_debug "Forcing inlining for" info.tt
695+
return src
696+
end
697+
return @invoke CC.inlining_policy($(args_ex.args...))
698+
end
699+
end
700+
else
701+
function CC.src_inlining_policy(interp::GPUInterpreter,
702+
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
703+
704+
if info isa NoInlineCallInfo
705+
@safe_debug "Blocking inlining" info.tt info.kind
706+
return false
707+
elseif info isa AlwaysInlineCallInfo
708+
@safe_debug "Forcing inlining for" info.tt
709+
return true
710+
end
711+
return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src, info::CC.CallInfo, stmt_flag::UInt32)
712+
end
713+
end
714+
545715

546716
## world view of the cache
547717
using Core.Compiler: WorldView
@@ -704,7 +874,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
704874
source = pop!(worklist)
705875
haskey(compiled, source) && continue # We have fulfilled the request already
706876
# Create a new compiler job for this edge, reusing the config settings from the initial one
707-
job2 = CompilerJob(source, job.config)
877+
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
708878
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
709879
append!(worklist, outstanding) # merge worklist with new outstanding edges
710880
@assert context(llvm_mod) == context(llvm_mod2)
@@ -844,7 +1014,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDi
8441014
end
8451015
end
8461016
if edges !== nothing
847-
for deferred_mi in (edges::DeferredEdges).edges
1017+
for edge in (edges::DeferredEdges).edges
1018+
# TODO
1019+
deferred_mi = edge.mi
8481020
if !haskey(compiled, deferred_mi)
8491021
push!(outstanding, deferred_mi)
8501022
end

test/native_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ end
166166
@testset "deferred" begin
167167
@gensym child kernel unrelated
168168
@eval @noinline $child(i) = i
169-
@eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
169+
@eval $kernel(i) = GPUCompiler.var"gpuc.deferred"(nothing, $child, i)
170170

171171
# smoke test
172172
job, _ = Native.create_job(eval(kernel), (Int64,))

0 commit comments

Comments
 (0)