Skip to content

Commit ec369e5

Browse files
committed
replace OpaqueClosures with a GPUCompiler compiler
1 parent 871b1ed commit ec369e5

File tree

2 files changed

+112
-4
lines changed

2 files changed

+112
-4
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1212
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
15+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
1516
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
17+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1618
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
1719
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1820
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/utils.jl

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ function certain_error()
316316
)
317317
end
318318

319-
function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
319+
function rewrite_inst(inst, ir::CC.IRCode, interp, RT, guaranteed_error)
320320
if Meta.isexpr(inst, :call)
321321
# Even if type unstable we do not want (or need) to replace intrinsic
322322
# calls or builtins with our version.
@@ -532,7 +532,7 @@ const DEBUG_INTERP = Ref(false)
532532
# to Any if our interpreter would change the return type of any result.
533533
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
534534
# screws up type inference after this (TODO this should be fixed).
535-
function rewrite_insts!(ir, interp, guaranteed_error)
535+
function rewrite_insts!(ir::CC.IRCode, interp, guaranteed_error)
536536
any_changed = false
537537
for (i, inst) in enumerate(ir.stmts)
538538
# Explicitly skip any code which returns Union{} so that we throw the error
@@ -839,11 +839,11 @@ function call_with_reactant_generator(
839839
if DEBUG_INTERP[]
840840
safe_print("code_info", code_info)
841841
end
842-
842+
#@lk code_info oc
843843
return code_info
844844
end
845845

846-
@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...)
846+
@eval function call_with_reactant0($REDUB_ARGUMENTS_NAME...)
847847
$(Expr(:meta, :generated_only))
848848
return $(Expr(:meta, :generated, call_with_reactant_generator))
849849
end
@@ -854,3 +854,109 @@ end
854854
nmantissa(::Type{Float16}) = 10
855855
nmantissa(::Type{Float32}) = 23
856856
nmantissa(::Type{Float64}) = 52
857+
858+
using GPUCompiler
859+
using GPUCompiler: AbstractCompilerParams, CompilerJob, NativeCompilerTarget
860+
861+
Base.Experimental.@MethodTable(test_method_table)
862+
863+
struct CompilerParams <: AbstractCompilerParams
864+
entry_safepoint::Bool
865+
method_table
866+
867+
function CompilerParams(entry_safepoint::Bool=false, method_table=test_method_table)
868+
return new(entry_safepoint, method_table)
869+
end
870+
end
871+
872+
NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams}
873+
874+
function GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob))
875+
return job.config.params.method_table
876+
end
877+
function GPUCompiler.can_safepoint(@nospecialize(job::NativeCompilerJob))
878+
return job.config.params.entry_safepoint
879+
end
880+
881+
GPUCompiler.can_throw(@nospecialize(job::NativeCompilerJob)) = true
882+
GPUCompiler.needs_byval(@nospecialize(job::NativeCompilerJob)) = false
883+
884+
function GPUCompiler.optimize!(
885+
@nospecialize(job::NativeCompilerJob), mod::GPUCompiler.LLVM.Module; opt_level
886+
)
887+
return nothing #TODO: add all except GPU stuff passes
888+
end
889+
890+
function create_job(
891+
@nospecialize(func),
892+
@nospecialize(types);
893+
entry_safepoint::Bool=false,
894+
method_table=test_method_table,
895+
kwargs...,
896+
)
897+
config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS)
898+
source = methodinstance(
899+
typeof(func), Base.to_tuple_type(types), Base.get_world_counter()
900+
)
901+
target = NativeCompilerTarget()
902+
params = CompilerParams(entry_safepoint, method_table)
903+
config = CompilerConfig(
904+
target, params; kernel=false, libraries=false, toplevel=true, config_kwargs...
905+
)
906+
return CompilerJob(source, config), kwargs
907+
end
908+
909+
using Enzyme
910+
ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{
911+
typeof(Reactant.set_reactant_abi)
912+
}
913+
914+
GPUCompiler.get_interpreter(::NativeCompilerJob) = Reactant.ReactantInterpreter()
915+
916+
917+
function CC.optimize(
918+
interp::ReactantInter, opt::CC.OptimizationState, caller::CC.InferenceResult
919+
)
920+
CC.@timeit "optimizer" ir = CC.run_passes_ipo_safe(opt.src, opt, caller)
921+
CC.ipo_dataflow_analysis!(interp, ir, caller)
922+
923+
mi = caller.linfo
924+
if false && !(
925+
is_reactant_method(mi) || (
926+
mi.def.sig isa DataType &&
927+
!should_rewrite_invoke(
928+
mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...}
929+
)
930+
)
931+
)
932+
@info ir
933+
ir, has_changed = rewrite_insts!(ir, interp, false)
934+
@info ir
935+
has_changed && @info "rewrite instruction $mi"
936+
end
937+
938+
939+
return CC.finish(interp, opt, ir, caller)
940+
end
941+
942+
function call_with_reactant(@nospecialize(args...))
943+
f = args[1]
944+
types = typeof.(args[2:end])
945+
946+
job, meta = Reactant.create_job(f, types; validate=false)
947+
llvm_module, meta_ = Reactant.JuliaContext() do ctx
948+
GPUCompiler.compile(:llvm, job)
949+
end
950+
mm = meta_.compiled[job.source]
951+
@error mm.ci.def types args
952+
expr = Expr(
953+
:call,
954+
GlobalRef(Base, :llvmcall),
955+
(string(llvm_module), mm.specfunc),
956+
mm.ci.rettype,
957+
Tuple{types...},
958+
args[2:end]...,
959+
)
960+
#TODO: replace with a generated function
961+
@eval $expr
962+
end

0 commit comments

Comments
 (0)