Skip to content

Commit 64e5353

Browse files
committed
Make GPUInterpreter extensible
Currently Enzyme uses its own AbstractInterpreter, in particular to block inlining of functions with custom rules and to handle nested autodiff operations. - [ ] Create a version of Enzyme with this - [ ] Support a version of `gpuc.deferred(meta)`
1 parent dfd5c35 commit 64e5353

File tree

2 files changed

+126
-10
lines changed

2 files changed

+126
-10
lines changed

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
181181
# provide a specific interpreter to use.
182182
if VERSION >= v"1.11.0-DEV.1552"
183183
get_interpreter(@nospecialize(job::CompilerJob)) =
184-
GPUInterpreter(job.world; method_table=method_table(job),
184+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
185185
token=ci_cache_token(job), inf_params=inference_params(job),
186186
opt_params=optimization_params(job))
187187
else
188188
get_interpreter(@nospecialize(job::CompilerJob)) =
189-
GPUInterpreter(job.world; method_table=method_table(job),
189+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
190190
code_cache=ci_cache(job), inf_params=inference_params(job),
191191
opt_params=optimization_params(job))
192192
end

src/jlgen.jl

Lines changed: 124 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ else
320320
end
321321

322322
struct GPUInterpreter <: CC.AbstractInterpreter
323+
meta::Any
323324
world::UInt
324325
method_table::GPUMethodTableView
325326

@@ -336,6 +337,7 @@ end
336337

337338
@static if HAS_INTEGRATED_CACHE
338339
function GPUInterpreter(world::UInt=Base.get_world_counter();
340+
meta = nothing,
339341
method_table::MTType,
340342
token::Any,
341343
inf_params::CC.InferenceParams,
@@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
345347
method_table = get_method_table_view(world, method_table)
346348
inf_cache = Vector{CC.InferenceResult}()
347349

348-
return GPUInterpreter(world, method_table,
350+
return GPUInterpreter(meta, world, method_table,
349351
token, inf_cache,
350352
inf_params, opt_params)
351353
end
352354

353355
function GPUInterpreter(interp::GPUInterpreter;
356+
meta=interp.meta,
354357
world::UInt=interp.world,
355358
method_table::GPUMethodTableView=interp.method_table,
356359
token::Any=interp.token,
357360
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
358361
inf_params::CC.InferenceParams=interp.inf_params,
359362
opt_params::CC.OptimizationParams=interp.opt_params)
360-
return GPUInterpreter(world, method_table,
363+
return GPUInterpreter(meta, world, method_table,
361364
token, inf_cache,
362365
inf_params, opt_params)
363366
end
364367

365368
else
366369

367370
function GPUInterpreter(world::UInt=Base.get_world_counter();
371+
meta=nothing,
368372
method_table::MTType,
369373
code_cache::CodeCache,
370374
inf_params::CC.InferenceParams,
@@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
374378
method_table = get_method_table_view(world, method_table)
375379
inf_cache = Vector{CC.InferenceResult}()
376380

377-
return GPUInterpreter(world, method_table,
381+
return GPUInterpreter(meta, world, method_table,
378382
code_cache, inf_cache,
379383
inf_params, opt_params)
380384
end
381385

382386
function GPUInterpreter(interp::GPUInterpreter;
387+
meta=interp.meta,
383388
world::UInt=interp.world,
384389
method_table::GPUMethodTableView=interp.method_table,
385390
code_cache::CodeCache=interp.code_cache,
386391
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
387392
inf_params::CC.InferenceParams=interp.inf_params,
388393
opt_params::CC.OptimizationParams=interp.opt_params)
389-
return GPUInterpreter(world, method_table,
394+
return GPUInterpreter(meta, world, method_table,
390395
code_cache, inf_cache,
391396
inf_params, opt_params)
392397
end
@@ -445,9 +450,11 @@ struct DeferredCallInfo <: CC.CallInfo
445450
end
446451

447452
# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
448-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
449-
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
450-
max_methods::Int = CC.get_max_methods(interp, f, sv))
453+
# default implementation, extensible through meta argument.
454+
# XXX: (or should we dispatch on `f`)?
455+
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
456+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
457+
max_methods::Int = CC.get_max_methods(interp, f, sv))
451458
(; fargs, argtypes) = arginfo
452459
if f === var"gpuc.deferred"
453460
argvec = argtypes[2:end]
@@ -459,6 +466,20 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
459466
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
460467
end
461468
end
469+
return nothing
470+
end
471+
472+
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
473+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
474+
max_methods::Int = CC.get_max_methods(interp, f, sv))
475+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
476+
if candidate === nothing && interp.meta !== nothing
477+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
478+
end
479+
if candidate !== nothing
480+
return candidate
481+
end
482+
462483
return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
463484
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
464485
max_methods::Int)
@@ -542,6 +563,101 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
542563
end
543564
end
544565

566+
import .CC: CallInfo
567+
"""
    NoInlineCallInfo(info, tt, kind)

`CallInfo` wrapper marking a call that must not be inlined: the `inlining_policy`
overload in this file returns `nothing` for it.

- `info`: the wrapped call info; all split queries are forwarded to it.
- `tt`: a type describing the call (stored untyped); only used for debug logging.
- `kind`: a symbol reported alongside `tt` in the debug log.
"""
struct NoInlineCallInfo <: CallInfo
    info::CallInfo  # wrapped call
    tt::Any         # ::Type
    kind::Symbol

    function NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt),
                              kind::Symbol)
        return new(info, tt, kind)
    end
end

# The wrapper is transparent to the compiler's split API: delegate to the wrapped info.
function CC.nsplit_impl(info::NoInlineCallInfo)
    return CC.nsplit(info.info)
end

function CC.getsplit_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
578+
"""
    AlwaysInlineCallInfo(info, tt)

`CallInfo` wrapper marking a call that should always be inlined: the `inlining_policy`
overload in this file returns the source unchanged for it.

- `info`: the wrapped call info; all split queries are forwarded to it.
- `tt`: a type describing the call (stored untyped); only used for debug logging.
"""
struct AlwaysInlineCallInfo <: CallInfo
    info::CallInfo  # wrapped call
    tt::Any         # ::Type

    function AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt))
        return new(info, tt)
    end
end

# The wrapper is transparent to the compiler's split API: delegate to the wrapped info.
function CC.nsplit_impl(info::AlwaysInlineCallInfo)
    return CC.nsplit(info.info)
end

function CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
587+
588+
589+
"""
    inlining_handler(meta, interp::GPUInterpreter, atype, callinfo)

Extension hook consulted after inference of a generic-function call. A method may return
a replacement `CallInfo` (e.g. a `NoInlineCallInfo` or `AlwaysInlineCallInfo`) to steer
inlining, or `nothing` to leave `callinfo` untouched.

This default method, selected when `meta === nothing`, never overrides anything.
"""
function inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype),
                          callinfo)
    # `interp` must be annotated with the interpreter type; `GPUCompiler` is the module
    # name and is not valid in a method signature.
    return nothing
end
592+
593+
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
594+
# Infer a generic-function call with the stock implementation, then let
# `inlining_handler` optionally swap the resulting call info (to force or block inlining).
# The handler for `interp.meta` is asked first; the default `nothing` handler is the
# fallback; if neither returns a replacement, the inferred call info is kept as-is.
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f),
                                     arginfo::ArgInfo, si::StmtInfo,
                                     @nospecialize(atype), sv::AbsIntState,
                                     max_methods::Int)
    ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any,
                                              arginfo::ArgInfo, si::StmtInfo,
                                              atype::Any, sv::AbsIntState,
                                              max_methods::Int)

    # `CallMeta` stores the call info in its `info` field.
    inferred = ret.info

    callinfo = nothing
    if interp.meta !== nothing
        callinfo = inlining_handler(interp.meta, interp, atype, inferred)
    end
    if callinfo === nothing
        callinfo = inlining_handler(nothing, interp, atype, inferred)
    end
    if callinfo === nothing
        callinfo = inferred
    end

    # The 1.11-and-later branch passes the extra exception-type field `exct`.
    @static if VERSION >= v"1.11-"
        return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
    else
        return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo)
    end
end
616+
617+
# Overload `Core.Compiler.inlining_policy` for `GPUInterpreter` so that calls tagged with
# `NoInlineCallInfo` are never inlined and calls tagged with `AlwaysInlineCallInfo` always
# are. Everything else defers to the stock policy.
#
# The argument list is version-dependent, so it is assembled as expressions first and
# spliced into both the method definition (`sigs_ex`) and the `@invoke` fallback
# (`args_ex`).
let # overload `inlining_policy`
    @static if VERSION >= v"1.11.0-DEV.879"
        # this branch: `UInt32` statement flag, no `mi` / `argtypes` arguments
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::Core.Compiler.CallInfo),
            stmt_flag::UInt32,
        )
        args_ex = :(
            interp::Core.Compiler.AbstractInterpreter,
            src::Any,
            info::Core.Compiler.CallInfo,
            stmt_flag::UInt32,
        )
    else
        # this branch: `UInt8` statement flag plus the method instance and argument types
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::Core.Compiler.CallInfo),
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
        args_ex = :(
            interp::Core.Compiler.AbstractInterpreter,
            src::Any,
            info::Core.Compiler.CallInfo,
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
    end
    @eval function Core.Compiler.inlining_policy($(sigs_ex.args...))
        if info isa NoInlineCallInfo
            @safe_debug "Blocking inlining" info.tt info.kind
            return nothing
        elseif info isa AlwaysInlineCallInfo
            @safe_debug "Forcing inlining for" info.tt
            return src
        end
        # untagged call: fall back to the generic `AbstractInterpreter` policy
        return @invoke Core.Compiler.inlining_policy($(args_ex.args...))
    end
end
660+
545661

546662
## world view of the cache
547663
using Core.Compiler: WorldView
@@ -704,7 +820,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
704820
source = pop!(worklist)
705821
haskey(compiled, source) && continue # We have fulfilled the request already
706822
# Create a new compiler job for this edge, reusing the config settings from the initial one
707-
job2 = CompilerJob(source, job.config)
823+
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
708824
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
709825
append!(worklist, outstanding) # merge worklist with new outstanding edges
710826
@assert context(llvm_mod) == context(llvm_mod2)

0 commit comments

Comments
 (0)