Skip to content

Commit f647ad6

Browse files
committed
Make GPUInterpreter extensible
Currently Enzyme uses its own AbstractInterpreter, in particular to block inlining of functions with custom rules and to handle nested autodiff operations. - [ ] Create a version of Enzyme that uses this - [ ] Support a version of `gpuc.deferred(meta)`
1 parent e9d1372 commit f647ad6

File tree

3 files changed

+151
-10
lines changed

3 files changed

+151
-10
lines changed

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
181181
# provide a specific interpreter to use.
182182
if VERSION >= v"1.11.0-DEV.1552"
183183
get_interpreter(@nospecialize(job::CompilerJob)) =
184-
GPUInterpreter(job.world; method_table=method_table(job),
184+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
185185
token=ci_cache_token(job), inf_params=inference_params(job),
186186
opt_params=optimization_params(job))
187187
else
188188
get_interpreter(@nospecialize(job::CompilerJob)) =
189-
GPUInterpreter(job.world; method_table=method_table(job),
189+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
190190
code_cache=ci_cache(job), inf_params=inference_params(job),
191191
opt_params=optimization_params(job))
192192
end

src/jlgen.jl

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ else
320320
end
321321

322322
struct GPUInterpreter <: CC.AbstractInterpreter
323+
meta::Any
323324
world::UInt
324325
method_table::GPUMethodTableView
325326

@@ -336,6 +337,7 @@ end
336337

337338
@static if HAS_INTEGRATED_CACHE
338339
function GPUInterpreter(world::UInt=Base.get_world_counter();
340+
meta = nothing,
339341
method_table::MTType,
340342
token::Any,
341343
inf_params::CC.InferenceParams,
@@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
345347
method_table = get_method_table_view(world, method_table)
346348
inf_cache = Vector{CC.InferenceResult}()
347349

348-
return GPUInterpreter(world, method_table,
350+
return GPUInterpreter(meta, world, method_table,
349351
token, inf_cache,
350352
inf_params, opt_params)
351353
end
352354

353355
function GPUInterpreter(interp::GPUInterpreter;
356+
meta=interp.meta,
354357
world::UInt=interp.world,
355358
method_table::GPUMethodTableView=interp.method_table,
356359
token::Any=interp.token,
357360
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
358361
inf_params::CC.InferenceParams=interp.inf_params,
359362
opt_params::CC.OptimizationParams=interp.opt_params)
360-
return GPUInterpreter(world, method_table,
363+
return GPUInterpreter(meta, world, method_table,
361364
token, inf_cache,
362365
inf_params, opt_params)
363366
end
364367

365368
else
366369

367370
function GPUInterpreter(world::UInt=Base.get_world_counter();
371+
meta=nothing,
368372
method_table::MTType,
369373
code_cache::CodeCache,
370374
inf_params::CC.InferenceParams,
@@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
374378
method_table = get_method_table_view(world, method_table)
375379
inf_cache = Vector{CC.InferenceResult}()
376380

377-
return GPUInterpreter(world, method_table,
381+
return GPUInterpreter(meta, world, method_table,
378382
code_cache, inf_cache,
379383
inf_params, opt_params)
380384
end
381385

382386
function GPUInterpreter(interp::GPUInterpreter;
387+
meta=interp.meta,
383388
world::UInt=interp.world,
384389
method_table::GPUMethodTableView=interp.method_table,
385390
code_cache::CodeCache=interp.code_cache,
386391
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
387392
inf_params::CC.InferenceParams=interp.inf_params,
388393
opt_params::CC.OptimizationParams=interp.opt_params)
389-
return GPUInterpreter(world, method_table,
394+
return GPUInterpreter(meta, world, method_table,
390395
code_cache, inf_cache,
391396
inf_params, opt_params)
392397
end
@@ -437,6 +442,8 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
437442
end
438443

439444

445+
# Returns `false` when executed as ordinary Julia code. The GPUInterpreter
# `abstract_call_known` handler in this file has a branch for this function that
# returns `Core.Const(true)` for zero-argument calls.
# NOTE(review): presumably meant to let code detect that it is being compiled
# through GPUCompiler — confirm.
within_gpucompiler() = false
446+
440447
## deferred compilation
441448

442449
struct DeferredCallInfo <: CC.CallInfo
@@ -445,9 +452,11 @@ struct DeferredCallInfo <: CC.CallInfo
445452
end
446453

447454
# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
448-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
449-
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
450-
max_methods::Int = CC.get_max_methods(interp, f, sv))
455+
# default implementation, extensible through meta argument.
456+
# XXX: (or should we dispatch on `f`)?
457+
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
458+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
459+
max_methods::Int = CC.get_max_methods(interp, f, sv))
451460
(; fargs, argtypes) = arginfo
452461
if f === var"gpuc.deferred"
453462
argvec = argtypes[2:end]
@@ -458,7 +467,34 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
458467
else
459468
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
460469
end
470+
else f === within_gpucompiler
471+
if length(argtypes) != 1
472+
@static if VERSION < v"1.11.0-"
473+
return CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
474+
else
475+
return CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
476+
end
477+
end
478+
@static if VERSION < v"1.11.0-"
479+
return CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
480+
else
481+
return CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),)
482+
end
461483
end
484+
return nothing
485+
end
486+
487+
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
488+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
489+
max_methods::Int = CC.get_max_methods(interp, f, sv))
490+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
491+
if candidate === nothing && interp.meta !== nothing
492+
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
493+
end
494+
if candidate !== nothing
495+
return candidate
496+
end
497+
462498
return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
463499
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
464500
max_methods::Int)
@@ -542,6 +578,101 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
542578
end
543579
end
544580

581+
import .CC: CallInfo
582+
# Call-info wrapper used to mark a call that must not be inlined: the
# `inlining_policy` overload for `GPUInterpreter` returns `nothing` for it.
struct NoInlineCallInfo <: CallInfo
    info::CallInfo  # the call info being wrapped
    tt::Any         # ::Type
    kind::Symbol    # logged by `inlining_policy` when inlining is blocked

    function NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt),
                              kind::Symbol)
        return new(info, tt, kind)
    end
end
589+
590+
# Forward the split queries of the `CallInfo` interface to the wrapped call info.
function CC.nsplit_impl(info::NoInlineCallInfo)
    return CC.nsplit(info.info)
end

function CC.getsplit_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
593+
# Call-info wrapper used to mark a call that should always be inlined: the
# `inlining_policy` overload for `GPUInterpreter` returns the source for it.
struct AlwaysInlineCallInfo <: CallInfo
    info::CallInfo  # the call info being wrapped
    tt::Any         # ::Type

    function AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt))
        return new(info, tt)
    end
end
598+
599+
# Forward the split queries of the `CallInfo` interface to the wrapped call info.
function CC.nsplit_impl(info::AlwaysInlineCallInfo)
    return Core.Compiler.nsplit(info.info)
end

function CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
602+
603+
604+
# Default inlining handler for `meta === nothing`. It never substitutes a call
# info: returning `nothing` makes `abstract_call_gf_by_type` keep the call info
# produced by regular inference.
inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo) =
    nothing
607+
608+
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
609+
# Run the regular generic-function call inference, then give the inlining
# handlers a chance to replace the resulting call info (e.g. by wrapping it in
# `NoInlineCallInfo` / `AlwaysInlineCallInfo`, which `inlining_policy` acts on).
#
# Handler lookup order:
#   1. the handler for `interp.meta`, when a `meta` is attached;
#   2. the default handler for `nothing`;
#   3. otherwise the call info from regular inference is kept unchanged.
# A handler signals "no opinion" by returning `nothing`.
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f),
                                     arginfo::ArgInfo, si::StmtInfo,
                                     @nospecialize(atype), sv::AbsIntState,
                                     max_methods::Int)
    ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any,
                                              arginfo::ArgInfo, si::StmtInfo,
                                              atype::Any, sv::AbsIntState,
                                              max_methods::Int)

    # `CallMeta` stores its call info in the `info` field.
    callinfo = nothing
    if interp.meta !== nothing
        callinfo = inlining_handler(interp.meta, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = inlining_handler(nothing, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = ret.info
    end

    # From 1.11 on, `CallMeta` also carries the inferred exception type.
    @static if VERSION >= v"1.11-"
        return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
    else
        return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo)
    end
end
631+
632+
# Overload `Core.Compiler.inlining_policy` for `GPUInterpreter` so that calls
# tagged with `NoInlineCallInfo` are never inlined and calls tagged with
# `AlwaysInlineCallInfo` always are. Every other call is delegated to the
# default policy of `AbstractInterpreter`.
#
# The signature of `inlining_policy` differs between Julia versions, so the
# argument lists are built as expressions and spliced into a single `@eval`'d
# method definition: `sigs_ex` is the method signature, `args_ex` the typed
# argument list for the `@invoke` fallback.
let # overload `inlining_policy`
    @static if VERSION >= v"1.11.0-DEV.879"
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::Core.Compiler.CallInfo),
            stmt_flag::UInt32,
        )
        args_ex = :(
            interp::Core.Compiler.AbstractInterpreter,
            src::Any,
            info::Core.Compiler.CallInfo,
            stmt_flag::UInt32,
        )
    else
        sigs_ex = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::Core.Compiler.CallInfo),
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
        args_ex = :(
            interp::Core.Compiler.AbstractInterpreter,
            src::Any,
            info::Core.Compiler.CallInfo,
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
    end
    @eval function Core.Compiler.inlining_policy($(sigs_ex.args...))
        if info isa NoInlineCallInfo
            @safe_debug "Blocking inlining" info.tt info.kind
            # `nothing` tells the inliner not to inline this call
            return nothing
        elseif info isa AlwaysInlineCallInfo
            @safe_debug "Forcing inlining for" info.tt
            # returning the source tells the inliner to inline it
            return src
        end
        return @invoke Core.Compiler.inlining_policy($(args_ex.args...))
    end
end
675+
545676

546677
## world view of the cache
547678
using Core.Compiler: WorldView
@@ -704,7 +835,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
704835
source = pop!(worklist)
705836
haskey(compiled, source) && continue # We have fulfilled the request already
706837
# Create a new compiler job for this edge, reusing the config settings from the initial one
707-
job2 = CompilerJob(source, job.config)
838+
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
708839
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
709840
append!(worklist, outstanding) # merge worklist with new outstanding edges
710841
@assert context(llvm_mod) == context(llvm_mod2)

test/ptx_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,16 @@ end
288288
@test !occursin("gpucompiler.mark", ir)
289289
end
290290

291+
# `within_gpucompiler()` returns `false` in regular Julia code, but the
# GPUInterpreter is expected to infer the call as the constant `true`, so the
# generated IR should store a constant 1.
@testset "within_gpucompiler" begin
    function kernel(a)
        unsafe_store!(a, GPUCompiler.within_gpucompiler())
    end
    # `unsafe_store!` needs a pointer argument; `sprint` needs a closure taking the IO
    ir = sprint(io->code_llvm(io, kernel, Tuple{Ptr{Bool}}))
    # NOTE(review): assumes `Bool` is stored as `i8` in the emitted IR — confirm
    @test occursin("store i8 1", ir)
end
298+
299+
300+
291301
@testset "exception arguments" begin
292302
function kernel(a)
293303
unsafe_store!(a, trunc(Int, unsafe_load(a)))

0 commit comments

Comments
 (0)