Skip to content

Commit 6f8532b

Browse files
committed
Make GPUInterpreter extensible
Currently Enzyme uses its own AbstractInterpreter, in particular to block inlining of functions with custom rules and to handle nested autodiff operations. - [ ] Create a version of Enzyme with this - [ ] Support a version of `gpuc.deferred(meta)`
1 parent e9d1372 commit 6f8532b

File tree

3 files changed

+164
-10
lines changed

3 files changed

+164
-10
lines changed

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
181181
# provide a specific interpreter to use.
182182
if VERSION >= v"1.11.0-DEV.1552"
183183
get_interpreter(@nospecialize(job::CompilerJob)) =
184-
GPUInterpreter(job.world; method_table=method_table(job),
184+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
185185
token=ci_cache_token(job), inf_params=inference_params(job),
186186
opt_params=optimization_params(job))
187187
else
188188
get_interpreter(@nospecialize(job::CompilerJob)) =
189-
GPUInterpreter(job.world; method_table=method_table(job),
189+
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
190190
code_cache=ci_cache(job), inf_params=inference_params(job),
191191
opt_params=optimization_params(job))
192192
end

src/jlgen.jl

Lines changed: 154 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ else
320320
end
321321

322322
struct GPUInterpreter <: CC.AbstractInterpreter
323+
meta::Any
323324
world::UInt
324325
method_table::GPUMethodTableView
325326

@@ -336,6 +337,7 @@ end
336337

337338
@static if HAS_INTEGRATED_CACHE
338339
function GPUInterpreter(world::UInt=Base.get_world_counter();
340+
meta = nothing,
339341
method_table::MTType,
340342
token::Any,
341343
inf_params::CC.InferenceParams,
@@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
345347
method_table = get_method_table_view(world, method_table)
346348
inf_cache = Vector{CC.InferenceResult}()
347349

348-
return GPUInterpreter(world, method_table,
350+
return GPUInterpreter(meta, world, method_table,
349351
token, inf_cache,
350352
inf_params, opt_params)
351353
end
352354

353355
function GPUInterpreter(interp::GPUInterpreter;
356+
meta=interp.meta,
354357
world::UInt=interp.world,
355358
method_table::GPUMethodTableView=interp.method_table,
356359
token::Any=interp.token,
357360
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
358361
inf_params::CC.InferenceParams=interp.inf_params,
359362
opt_params::CC.OptimizationParams=interp.opt_params)
360-
return GPUInterpreter(world, method_table,
363+
return GPUInterpreter(meta, world, method_table,
361364
token, inf_cache,
362365
inf_params, opt_params)
363366
end
364367

365368
else
366369

367370
function GPUInterpreter(world::UInt=Base.get_world_counter();
371+
meta=nothing,
368372
method_table::MTType,
369373
code_cache::CodeCache,
370374
inf_params::CC.InferenceParams,
@@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
374378
method_table = get_method_table_view(world, method_table)
375379
inf_cache = Vector{CC.InferenceResult}()
376380

377-
return GPUInterpreter(world, method_table,
381+
return GPUInterpreter(meta, world, method_table,
378382
code_cache, inf_cache,
379383
inf_params, opt_params)
380384
end
381385

382386
function GPUInterpreter(interp::GPUInterpreter;
387+
meta=interp.meta,
383388
world::UInt=interp.world,
384389
method_table::GPUMethodTableView=interp.method_table,
385390
code_cache::CodeCache=interp.code_cache,
386391
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
387392
inf_params::CC.InferenceParams=interp.inf_params,
388393
opt_params::CC.OptimizationParams=interp.opt_params)
389-
return GPUInterpreter(world, method_table,
394+
return GPUInterpreter(meta, world, method_table,
390395
code_cache, inf_cache,
391396
inf_params, opt_params)
392397
end
@@ -437,6 +442,8 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
437442
end
438443

439444

445+
"""
    within_gpucompiler()

Return `false` when executed as ordinary Julia code.

The default (`meta === nothing`) `abstract_call_known` handler for `GPUInterpreter`
recognizes zero-argument calls to this function and infers them as the constant `true`.
"""
function within_gpucompiler()
    return false
end
446+
440447
## deferred compilation
441448

442449
struct DeferredCallInfo <: CC.CallInfo
@@ -445,9 +452,11 @@ struct DeferredCallInfo <: CC.CallInfo
445452
end
446453

447454
# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
448-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
449-
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
450-
max_methods::Int = CC.get_max_methods(interp, f, sv))
455+
# default implementation, extensible through meta argument.
456+
# XXX: (or should we dispatch on `f`)?
457+
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
458+
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
459+
max_methods::Int = CC.get_max_methods(interp, f, sv))
451460
(; fargs, argtypes) = arginfo
452461
if f === var"gpuc.deferred"
453462
argvec = argtypes[2:end]
@@ -458,7 +467,34 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
458467
else
459468
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
460469
end
470+
elseif f === within_gpucompiler
471+
if length(argtypes) != 1
472+
@static if VERSION < v"1.11.0-"
473+
return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
474+
else
475+
return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
476+
end
477+
end
478+
@static if VERSION < v"1.11.0-"
479+
return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
480+
else
481+
return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),)
482+
end
461483
end
484+
return nothing
485+
end
486+
487+
# Inference hook for calls to a known callee `f` under a `GPUInterpreter`.
#
# Handlers are consulted in order; the first one returning something other than
# `nothing` wins:
#   1. the handler selected by `interp.meta` (extension point),
#   2. the default `meta === nothing` handler, which recognizes `gpuc.deferred` and
#      `within_gpucompiler`,
#   3. the stock `CC.AbstractInterpreter` implementation.
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                                max_methods::Int = CC.get_max_methods(interp, f, sv))
    candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
    if candidate === nothing && interp.meta !== nothing
        # The meta-specific handler declined: fall back to the default handler.
        # (Calling with `interp.meta` again would just repeat the same query and
        # skip the built-in handling entirely.)
        candidate = abstract_call_known(nothing, interp, f, arginfo, si, sv, max_methods)
    end
    if candidate !== nothing
        return candidate
    end

    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
                                          arginfo::CC.ArgInfo, si::CC.StmtInfo,
                                          sv::CC.AbsIntState, max_methods::Int)
@@ -542,6 +578,116 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
542578
end
543579
end
544580

581+
import .CC: CallInfo
582+
"""
    NoInlineCallInfo(info, tt, kind)

Call-info wrapper that the `GPUInterpreter` inlining policy treats as "do not inline":
the policy logs `tt` and `kind` at debug level and refuses to inline the call.
"""
struct NoInlineCallInfo <: CallInfo
    info::CallInfo  # the wrapped call info; split queries are forwarded to it
    tt::Any         # signature of the call (expected to be a `Type`)
    kind::Symbol    # tag reported in the debug message when inlining is blocked

    function NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt),
                              kind::Symbol)
        return new(info, tt, kind)
    end
end
589+
590+
# `NoInlineCallInfo` is a transparent wrapper: every split-related query is answered
# by the wrapped call info.
function CC.nsplit_impl(info::NoInlineCallInfo)
    return CC.nsplit(info.info)
end

function CC.getsplit_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
593+
"""
    AlwaysInlineCallInfo(info, tt)

Call-info wrapper that the `GPUInterpreter` inlining policy treats as "always inline":
the policy logs `tt` at debug level and accepts the callee source for inlining.
"""
struct AlwaysInlineCallInfo <: CallInfo
    info::CallInfo  # the wrapped call info; split queries are forwarded to it
    tt::Any         # signature of the call (expected to be a `Type`)

    function AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt))
        return new(info, tt)
    end
end
598+
599+
# `AlwaysInlineCallInfo` is a transparent wrapper: every split-related query is answered
# by the wrapped call info. All three methods go through the `CC` alias, matching the
# `NoInlineCallInfo` methods.
CC.nsplit_impl(info::AlwaysInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
602+
603+
604+
"""
    inlining_handler(meta, interp, atype, callinfo)

Extension point queried by the `GPUInterpreter` override of `abstract_call_gf_by_type`.
A method may return a replacement call info for the call with signature `atype`, or
`nothing` to leave `callinfo` untouched.

This default method (`meta === nothing`) never replaces the call info.
"""
inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo) =
    nothing
607+
608+
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
609+
# Run the stock generic-function inference, then give the inlining handlers a chance to
# replace the resulting call info (e.g. with `NoInlineCallInfo`/`AlwaysInlineCallInfo`).
#
# Handlers are consulted in order: the one selected by `interp.meta` (if any), then the
# default `meta === nothing` handler. If both return `nothing`, the call info produced
# by inference is kept unchanged.
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f),
                                     arginfo::ArgInfo, si::StmtInfo,
                                     @nospecialize(atype), sv::AbsIntState,
                                     max_methods::Int)
    ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any,
                                              arginfo::ArgInfo, si::StmtInfo,
                                              atype::Any, sv::AbsIntState,
                                              max_methods::Int)

    callinfo = nothing
    if interp.meta !== nothing
        callinfo = inlining_handler(interp.meta, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = inlining_handler(nothing, interp, atype, ret.info)
    end
    if callinfo === nothing
        callinfo = ret.info
    end

    # `CallMeta` carries an extra exception-type field (`exct`) on Julia 1.11 and later.
    @static if VERSION >= v"1.11-"
        return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
    else
        return CC.CallMeta(ret.rt, ret.effects, callinfo)
    end
end
631+
632+
# Inlining policy for `GPUInterpreter`: call sites tagged with `NoInlineCallInfo` are
# never inlined, call sites tagged with `AlwaysInlineCallInfo` are always inlined, and
# everything else is deferred to the stock `CC.AbstractInterpreter` policy.
#
# The hook to overload depends on the Julia version:
#   * before 1.12.0-DEV.45: `CC.inlining_policy`, which returns the source to inline or
#     `nothing`; its signature changed again in 1.11.0-DEV.879,
#   * afterwards: `CC.src_inlining_policy`, which returns a `Bool`.
@static if VERSION < v"1.12.0-DEV.45"
    let # overload `inlining_policy`
        # `sigs_ex` holds the method signature, `args_ex` the argument list used to
        # `@invoke` the generic fallback; both are spliced into the `@eval` below.
        @static if VERSION >= v"1.11.0-DEV.879"
            sigs_ex = :(
                interp::GPUInterpreter,
                @nospecialize(src),
                @nospecialize(info::CC.CallInfo),
                stmt_flag::UInt32,
            )
            args_ex = :(
                interp::CC.AbstractInterpreter,
                src::Any,
                info::CC.CallInfo,
                stmt_flag::UInt32,
            )
        else
            sigs_ex = :(
                interp::GPUInterpreter,
                @nospecialize(src),
                @nospecialize(info::CC.CallInfo),
                stmt_flag::UInt8,
                mi::MethodInstance,
                argtypes::Vector{Any},
            )
            args_ex = :(
                interp::CC.AbstractInterpreter,
                src::Any,
                info::CC.CallInfo,
                stmt_flag::UInt8,
                mi::MethodInstance,
                argtypes::Vector{Any},
            )
        end
        @eval function CC.inlining_policy($(sigs_ex.args...))
            if info isa NoInlineCallInfo
                @safe_debug "Blocking inlining" info.tt info.kind
                return nothing
            elseif info isa AlwaysInlineCallInfo
                @safe_debug "Forcing inlining for" info.tt
                return src
            end
            return @invoke CC.inlining_policy($(args_ex.args...))
        end
    end
else
    function CC.src_inlining_policy(interp::GPUInterpreter,
                                    @nospecialize(src), @nospecialize(info::CC.CallInfo),
                                    stmt_flag::UInt32)

        if info isa NoInlineCallInfo
            @safe_debug "Blocking inlining" info.tt info.kind
            return false
        elseif info isa AlwaysInlineCallInfo
            @safe_debug "Forcing inlining for" info.tt
            return true
        end
        return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src,
                                              info::CC.CallInfo, stmt_flag::UInt32)
    end
end
690+
545691

546692
## world view of the cache
547693
using Core.Compiler: WorldView
@@ -704,7 +850,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
704850
source = pop!(worklist)
705851
haskey(compiled, source) && continue # We have fulfilled the request already
706852
# Create a new compiler job for this edge, reusing the config settings from the initial one
707-
job2 = CompilerJob(source, job.config)
853+
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
708854
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
709855
append!(worklist, outstanding) # merge worklist with new outstanding edges
710856
@assert context(llvm_mod) == context(llvm_mod2)

test/ptx_tests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,14 @@ end
288288
@test !occursin("gpucompiler.mark", ir)
289289
end
290290

291+
@testset "within_gpucompiler" begin
    # Kernel that stores the value of `within_gpucompiler()` through its argument.
    kernel(a) = unsafe_store!(a, GPUCompiler.within_gpucompiler())

    ir = sprint() do io
        code_llvm(io, kernel, Tuple{Int})
    end
    # NOTE(review): the generated IR is only printed; no `@test` checks it yet.
    @show ir
end
298+
291299
@testset "exception arguments" begin
292300
function kernel(a)
293301
unsafe_store!(a, trunc(Int, unsafe_load(a)))

0 commit comments

Comments
 (0)