diff --git a/Project.toml b/Project.toml index ff850fa686..d02d867aa0 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -33,7 +34,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" -GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" @@ -103,9 +103,9 @@ Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" -unzip_jll = "6" YaoBlocks = "0.13, 0.14" julia = "1.10" +unzip_jll = "6" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 4cdb4e6d89..13891af659 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1495,7 +1495,6 @@ end Reactant.XLA.free_client(client) client.client = C_NULL Reactant.deinitialize_dialect() - Reactant.clear_oc_cache() end end diff --git a/src/Precompile.jl b/src/Precompile.jl index 33cb346185..1d064342a4 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -35,19 +35,6 @@ function infer_sig(sig) end end -function clear_oc_cache() - # Opaque closures capture the worldage of their compilation and thus are not relocatable - # Therefore we explicitly purge all OC's we have created here - for v in oc_capture_vec - if v isa Base.RefValue - p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) - Base.atomic_pointerset(p, C_NULL, :monotonic) - else - empty!(v) - end - end -end - # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 function precompilation_supported() return VERSION >= v"1.11" || VERSION >= v"1.10.8" @@ -78,6 +65,5 @@ if Reactant_jll.is_available() XLA.free_client(client) client.client = C_NULL deinitialize_dialect() - clear_oc_cache() end end diff --git a/src/utils.jl b/src/utils.jl index bed592af38..7a89060b09 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -150,7 +150,7 @@ will need a `@reactant_overlay` method. !!! warning The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope + mark it for precompilation, you must add the macro call in the global scope too. See also: [`@skip_rewrite_type`](@ref) @@ -189,7 +189,7 @@ abstract type, you should use then the `Type{<:MyStruct}` syntax. !!! warning The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope + mark it for precompilation, you must add the macro call in the global scope too. """ macro skip_rewrite_type(typ) @@ -316,7 +316,7 @@ function certain_error() ) end -function rewrite_inst(inst, ir, interp, RT, guaranteed_error) +function rewrite_inst(inst, ir::CC.IRCode, interp, RT, guaranteed_error) if Meta.isexpr(inst, :call) # Even if type unstable we do not want (or need) to replace intrinsic # calls or builtins with our version. @@ -449,78 +449,6 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) return false, inst, RT end -const oc_capture_vec = Vector{Any}() - -# Caching is both good to reducing compile times and necessary to work around julia bugs -# in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 -function make_oc_dict( - @nospecialize(oc_captures::Dict{FT,Core.OpaqueClosure}), - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f::FT) -)::Core.OpaqueClosure where {FT} - key = f - if haskey(oc_captures, key) - oc = oc_captures[key] - oc - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[key] = ores - return ores - end -end - -function make_oc_ref( - oc_captures::Base.RefValue{Core.OpaqueClosure}, - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f) -)::Core.OpaqueClosure - if Base.isassigned(oc_captures) - return oc_captures[] - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[] = ores - return ores - end -end - function safe_print(name, x) return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) end @@ -532,7 +460,7 @@ const DEBUG_INTERP = Ref(false) # to Any if our interpreter would change the return type of any result. # Also rewrite invoke (type stable call) to be :call, since otherwise apparently # screws up type inference after this (TODO this should be fixed). -function rewrite_insts!(ir, interp, guaranteed_error) +function rewrite_insts!(ir::CC.IRCode, interp, guaranteed_error) any_changed = false for (i, inst) in enumerate(ir.stmts) # Explicitly skip any code which returns Union{} so that we throw the error @@ -553,35 +481,80 @@ function rewrite_insts!(ir, interp, guaranteed_error) return ir, any_changed end -function rewrite_argnumbers_by_one!(ir) - # Add one dummy argument at the beginning - pushfirst!(ir.argtypes, Nothing) - - # Re-write all references to existing arguments to their new index (N + 1) - for idx in 1:length(ir.stmts) - urs = Core.Compiler.userefs(ir.stmts[idx][:inst]) - changed = false - it = Core.Compiler.iterate(urs) - while it !== nothing - (ur, next) = it - old = Core.Compiler.getindex(ur) - if old isa Core.Argument - # Replace the Argument(n) with Argument(n + 1) - Core.Compiler.setindex!(ur, Core.Argument(old.n + 1)) - changed = true - end - it = Core.Compiler.iterate(urs, next) - end - if changed - @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :inst) - else - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :stmt) - end - end +using GPUCompiler +using GPUCompiler: AbstractCompilerParams, CompilerJob, NativeCompilerTarget + +struct CompilerParams <: AbstractCompilerParams + function CompilerParams() + return new() + end +end + +NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams} + +GPUCompiler.can_safepoint(@nospecialize(job::NativeCompilerJob)) = false +GPUCompiler.can_throw(@nospecialize(job::NativeCompilerJob)) = true +GPUCompiler.needs_byval(@nospecialize(job::NativeCompilerJob)) = false + +function GPUCompiler.optimize!( + @nospecialize(job::NativeCompilerJob), mod::GPUCompiler.LLVM.Module; opt_level +) + return nothing #TODO: add all except GPU stuff passes +end + +using Enzyme +ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{ + typeof(Reactant.set_reactant_abi) +} + +function GPUCompiler.get_interpreter(@nospecialize(job::NativeCompilerJob)) + return Reactant.ReactantInterpreter(; world=job.world) +end +function GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) + return CC.method_table(GPUCompiler.get_interpreter(job)) +end + +function GPUCompiler.llvm_debug_info(@nospecialize(::NativeCompilerJob)) + return GPUCompiler.LLVM.API.LLVMDebugEmissionKindNoDebug +end + +function CC.optimize( + interp::ReactantInter, opt::CC.OptimizationState, caller::CC.InferenceResult +) + CC.@timeit "optimizer" ir = CC.run_passes_ipo_safe(opt.src, opt, caller) + CC.ipo_dataflow_analysis!(interp, ir, caller) + + mi = caller.linfo + if false && + mi in mi_set && + !( + is_reactant_method(mi) || ( + mi.def.sig isa DataType && + !should_rewrite_invoke( + mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...} + ) + ) + ) + @info ir + ir, has_changed = rewrite_insts!(ir, interp, false) + @info ir + has_changed && @info "rewrite instruction $mi" end - return nothing + return CC.finish(interp, opt, ir, caller) +end + +function GPUCompiler.ci_cache_populate( + interp::Reactant.ReactantInter, + cache::CC.WorldView{CC.InternalCodeCache}, + mi::Core.MethodInstance, + min_world::UInt64, + max_world::UInt64, +) + @warn mi min_world max_world CC.get_inference_world(interp) + @invoke GPUCompiler.ci_cache_populate( + interp::CC.AbstractInterpreter, cache, mi, min_world, max_world + ) end # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter @@ -591,89 +564,33 @@ end # replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia # using a custom interpreter in type unstable code. # `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` + function call_with_reactant_generator( world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments) ) @nospecialize args = redub_arguments - if DEBUG_INTERP[] - safe_print("args", args) - end stub = Core.GeneratedFunctionStub( - identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() + identity, Core.svec(:call_with_reactant, Reactant.REDUB_ARGUMENTS_NAME), Core.svec() ) - fn = args[1] - sig = Tuple{args...} - guaranteed_error = false - if fn === MustThrowError + if args[1] === Reactant.MustThrowError guaranteed_error = true - fn = args[2] - sig = Tuple{args[2:end]...} end - - # look up the method match - builtin_error = - :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn"))) + offset_error = guaranteed_error ? 1 : 0 + fn = args[1 + offset_error] if fn <: Core.Builtin + builtin_error = + :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn"))) return stub(world, source, builtin_error) end - if guaranteed_error - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world) - )) - else - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) - )) - end - - interp = ReactantInterpreter(; world) - - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - lookup_result = lookup_world( - sig, world, Core.Compiler.method_table(interp), min_world, max_world - ) - overdubbed_code = Any[] overdubbed_codelocs = Int32[] - # No method could be found (including in our method table), bail with an error - if lookup_result === nothing - return stub(world, source, method_error) - end - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - method = mi.def - - @static if VERSION < v"1.11" - # For older Julia versions, we vendor in some of the code to prevent - # having to build the MethodInstance twice. - result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) - frame = CC.InferenceState(result, :no, interp) - @assert !isnothing(frame) - CC.typeinf(interp, frame) - ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) - rt = CC.widenconst(CC.ignorelimited(result.result)) - else - ir, rt = CC.typeinf_ircode(interp, mi, nothing) - end - if guaranteed_error if rt !== Union{} safe_print("Inconsistent guaranteed error IR", ir) @@ -681,50 +598,60 @@ function call_with_reactant_generator( rt = Union{} end - if DEBUG_INTERP[] - safe_print("ir", ir) - end + mi = GPUCompiler.methodinstance( + fn, Base.to_tuple_type(args[(2 + offset_error):end]), world + ) + if mi === nothing + method_error = :(throw( + MethodError( + $REDUB_ARGUMENTS_NAME[1 + offset_error], + $REDUB_ARGUMENTS_NAME[(2 + offset_error):end], + $world, + ), + )) + return stub(world, mi, method_error) + end + config = CompilerConfig( + Reactant.NativeCompilerTarget(), + Reactant.CompilerParams(); + kernel=false, + libraries=false, + toplevel=true, + validate=false, + strip=true, + entry_abi=:func, + ) - mi = mi::Core.MethodInstance + job = GPUCompiler.CompilerJob(mi, config, world) - if !( - is_reactant_method(mi) || ( - mi.def.sig isa DataType && - !should_rewrite_invoke( - mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...} - ) - ) - ) || guaranteed_error - ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error) + llvm_module, meta_ = Reactant.JuliaContext() do _ctx + GPUCompiler.compile(:llvm, job) end + mm = meta_.compiled[job.source] - rewrite_argnumbers_by_one!(ir) - - src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) - src.slotnames = fill(:none, length(ir.argtypes) + 1) - src.slotflags = fill(zero(UInt8), length(ir.argtypes)) - src.slottypes = copy(ir.argtypes) - src.rettype = rt - src = CC.ir_to_codeinf!(src, ir) - - if DEBUG_INTERP[] - safe_print("src", src) + #CodeInfo placehold + code_info = begin + ir = CC.IRCode() + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = Int + CC.ir_to_codeinf!(src, ir) end - # prepare a new code info - code_info = copy(src) - static_params = match.sparams - signature = sig + rt = mm.ci.rettype + code_info.rettype = rt # propagate edge metadata, this method is invalidated if the original function we are calling # is invalidated - code_info.edges = Core.MethodInstance[mi] - code_info.min_world = min_world[] - code_info.max_world = max_world[] + code_info.edges = Core.MethodInstance[job.source] + code_info.min_world = typemin(UInt) + code_info.max_world = typemax(UInt) # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, # and the REDUB_ARGUMENTS_NAME tuple of input arguments - code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] + code_info.slotnames = Any[:call_with_reactant, Reactant.REDUB_ARGUMENTS_NAME] code_info.slotflags = UInt8[0x00, 0x00] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -744,115 +671,43 @@ function call_with_reactant_generator( # destructure the generated argument slots into the overdubbed method's argument slots. - offset = 1 + offset = 2 fn_args = Any[] - n_method_args = method.nargs + method = job.source.def n_actual_args = length(redub_arguments) - if guaranteed_error - offset += 1 - n_actual_args -= 1 - end + offset += offset_error + n_actual_args -= offset_error tys = [] - - iter_args = n_actual_args - if method.isva - iter_args = min(n_actual_args, n_method_args - 1) - end - - for i in 1:iter_args + for i in 2:n_actual_args + type = redub_arguments[i + (guaranteed_error ? 1 : 0)] actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) arg = push_inst!(actual_argument) - offset += 1 push!(fn_args, arg) - push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) - - if DEBUG_INTERP[] - push_inst!( - Expr( - :call, - safe_print, - "fn arg[" * string(length(fn_args)) * "]", - fn_args[end], - ), - ) - end - end - - # If `method` is a varargs method, we have to restructure the original method call's - # trailing arguments into a tuple and assign that tuple to the expected argument slot. - if method.isva - trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) - for i in n_method_args:n_actual_args - arg = push_inst!( - Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset) - ) - push!(trailing_arguments.args, arg) - offset += 1 - end - - push!(fn_args, push_inst!(trailing_arguments)) - push!( - tys, - Tuple{ - redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]..., - }, - ) - - if DEBUG_INTERP[] - push_inst!( - Expr( - :call, - safe_print, - "fn arg[" * string(length(fn_args)) * "]", - fn_args[end], - ), - ) - end - end - - # ocva = method.isva - - ocva = false # method.isva - - ocnargs = Int(method.nargs) - # octup = Tuple{mi.specTypes.parameters[2:end]...} - # octup = Tuple{method.sig.parameters[2:end]...} - octup = Tuple{tys[1:end]...} - ocva = false - - # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right - # inner code during compilation without special handling (i.e. call_in_world_total). - # Opaque closures also require taking the function argument. We can work around the latter - # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure - - dict, make_oc = (Base.Ref{Core.OpaqueClosure}(), make_oc_ref) - - push!(oc_capture_vec, dict) - - oc = if false && Base.issingletontype(fn) - res = Core._call_in_world_total( - world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance - )::Core.OpaqueClosure - else - farg = fn_args[1] - farg = nothing - rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) - push_inst!(rep) - Core.SSAValue(length(overdubbed_code)) - end - - push_inst!(Expr(:call, oc, fn_args[1:end]...)) - - ocres = Core.SSAValue(length(overdubbed_code)) - - if DEBUG_INTERP[] - push_inst!(Expr(:call, safe_print, "ocres", ocres)) + push!(tys, Base.RefValue{type}) + offset += 1 end - push_inst!(Core.ReturnNode(ocres)) + #force the creation of Any[fn_args...] + fn_args_vec = push_inst!( + Expr(:call, GlobalRef(Base, :getindex), GlobalRef(Base, :Any), fn_args...) + ) + pointer = push_inst!(Expr(:call, GlobalRef(Base, :pointer), fn_args_vec)) + boxed_res = push_inst!( + Expr( + :call, + GlobalRef(Base, :llvmcall), + (string(llvm_module), mm.func), + Any, + Tuple{Base.RefValue{fn},Ptr{Any},Int32}, + fn, + pointer, + Int32(length(fn_args)), + ), + ) + push_inst!(Core.ReturnNode(boxed_res)) #=== set `code_info`/`reflection` fields accordingly ===# @@ -864,15 +719,10 @@ function call_with_reactant_generator( code_info.codelocs = overdubbed_codelocs code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - - if DEBUG_INTERP[] - safe_print("code_info", code_info) - end - return code_info end -@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) +@eval function call_with_reactant($(Reactant.REDUB_ARGUMENTS_NAME)...) $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end @@ -882,4 +732,4 @@ end end nmantissa(::Type{Float16}) = 10 nmantissa(::Type{Float32}) = 23 -nmantissa(::Type{Float64}) = 52 +nmantissa(::Type{Float64}) = 52 \ No newline at end of file