diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 6ffcec0ab3..047815d83d 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -4,107 +4,3 @@ # - https://github.com/JuliaLang/julia/blob/v1.10.4/test/compiler/newinterp.jl#L9 const CC = Core.Compiler -using Enzyme - -import Core.Compiler: - AbstractInterpreter, - abstract_call, - abstract_call_known, - ArgInfo, - StmtInfo, - AbsIntState, - get_max_methods, - CallMeta, - Effects, - NoCallInfo, - MethodResultPure - -Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) - -function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) - return Base.Experimental.var"@overlay"( - __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def - ) -end - -function set_reactant_abi( - interp, - @nospecialize(f), - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int=get_max_methods(interp, f, sv), -) - (; fargs, argtypes) = arginfo - - if f === ReactantCore.within_compile - if length(argtypes) != 1 - @static if VERSION < v"1.11.0-" - return CallMeta(Union{}, Effects(), NoCallInfo()) - else - return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) - end - end - @static if VERSION < v"1.11.0-" - return CallMeta( - Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() - ) - else - return CallMeta( - Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() - ) - end - end - - # Improve inference by considering call_with_reactant as having the same results as - # the original call - if f === call_with_reactant - arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) - return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) - end - - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - f::Any, - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) -end - -@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE - struct ReactantCacheToken end - - function ReactantInterpreter(; world::UInt=Base.get_world_counter()) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( - ReactantCacheToken(), - REACTANT_METHOD_TABLE, - world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, - ) - end -else - const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache() - - function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE - ) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( - REACTANT_CACHE, - REACTANT_METHOD_TABLE, - world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, - ) - end -end diff --git a/src/JIT.jl b/src/JIT.jl new file mode 100644 index 0000000000..6cf144dd9b --- /dev/null +++ b/src/JIT.jl @@ -0,0 +1,692 @@ +using GPUCompiler +CC = Core.Compiler + +#leak each argument to a global variable +macro lk(args...) + quote + $([:( + let val = $(esc(p)) + global $(esc(p)) = val + end + ) for p in args]...) + end +end + +Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) + +function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) + return Base.Experimental.var"@overlay"( + __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def + ) +end + +function call_with_reactant() end + +@noinline call_with_native(@nospecialize(f), @nospecialize(args...)) = + Base.inferencebarrier(f)(args...) 
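+# Hedged usage sketch (illustrative only, not part of this changeset): the
+# overlay macro installs a method into REACTANT_METHOD_TABLE so it is visible
+# only under the Reactant interpreter; `my_op` and `Ops.some_primitive` are
+# hypothetical names.
+#
+#     @reactant_overlay function my_op(x::TracedRArray)
+#         return Ops.some_primitive(x)
+#     end
+#
+# Conversely, `call_with_native(f, args...)` escapes the rewritten world: the
+# inference barrier hides `f` from the interpreter so the call keeps plain
+# Julia semantics.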
+ +const __skip_rewrite_func_set = Set([ + typeof(call_with_reactant), + typeof(call_with_native), + typeof(task_local_storage), + typeof(getproperty), + typeof(invokelatest), +]) +const __skip_rewrite_func_set_lock = ReentrantLock() + +""" + @skip_rewrite_func f + +Mark function `f` so that Reactant's IR rewrite mechanism will skip it. +This can improve compilation time if it's safe to assume that no call inside `f` +will need a `@reactant_overlay` method. + +!!! info + Note that this marks the whole function, not a specific method with a type + signature. + +!!! warning + The macro call should be inside the `__init__` function. If you want to + mark it for precompilation, you must add the macro call in the global scope + too. + +See also: [`@skip_rewrite_type`](@ref) +""" +macro skip_rewrite_func(fname) + quote + @lock $(Reactant.__skip_rewrite_func_set_lock) push!( + $(Reactant.__skip_rewrite_func_set), typeof($(esc(fname))) + ) + end +end + +const __skip_files = Set([Symbol("sysimg.jl"), Symbol("boot.jl")]) + +struct CompilerParams <: AbstractCompilerParams + function CompilerParams() + return new() + end +end + +include("auto_cf/analysis.jl") + +MethodInstanceKey = Tuple{Vector{Type}} +function mi_key(mi::Core.MethodInstance) + return collect(Base.unwrap_unionall(mi.specTypes).parameters) +end +@kwdef struct MetaData + traced_tree_map::Dict{MethodInstanceKey,Tree} = Dict() +end + +@kwdef struct DebugData + enable_log::Bool = true + enable_runtime_log::Bool = true + rewrite_call::Set = Set() + non_rewrite_call::Set = Set() +end + +struct ReactantToken end + +@kwdef struct ReactantInterpreter <: CC.AbstractInterpreter + token::ReactantToken = ReactantToken() + # Cache of inference results for this particular interpreter + local_cache::Vector{CC.InferenceResult} = CC.InferenceResult[] + # The world age we're working inside of + world::UInt = Base.get_world_counter() + + # Parameters for inference and optimization + inf_params::CC.InferenceParams = CC.InferenceParams() + opt_params::CC.OptimizationParams = CC.OptimizationParams() + + meta_data::Ref{MetaData} = Ref(MetaData()) + debug_data::Ref{DebugData} = Ref(DebugData()) +end + +log(interp::ReactantInterpreter)::Bool = interp.debug_data[].enable_log +runtime_log(interp::ReactantInterpreter)::Bool = interp.debug_data[].enable_runtime_log +reset_debug_data(interp::ReactantInterpreter) = interp.debug_data[] = DebugData(); + +NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams} +GPUCompiler.can_throw(@nospecialize(job::NativeCompilerJob)) = true +function GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) + return CC.method_table(GPUCompiler.get_interpreter(job)) +end + +current_interpreter = Ref{Union{Nothing,ReactantInterpreter}}(nothing) + +function GPUCompiler.get_interpreter(@nospecialize(job::NativeCompilerJob)) + isnothing(current_interpreter[]) && + (return current_interpreter[] = ReactantInterpreter(; world=job.world)) + + if job.world == current_interpreter[].world + current_interpreter[] + else + (; meta_data, debug_data) = current_interpreter[] + current_interpreter[] = ReactantInterpreter(; + world=job.world, meta_data, debug_data + ) + end +end + +@noinline barrier(@nospecialize(x), @nospecialize(T::Type = Any)) = + Core.Compiler.inferencebarrier(x)::T + +CC.InferenceParams(@nospecialize(interp::ReactantInterpreter)) = interp.inf_params +CC.OptimizationParams(@nospecialize(interp::ReactantInterpreter)) = interp.opt_params +CC.get_inference_world(@nospecialize(interp::ReactantInterpreter)) = 
interp.world +CC.get_inference_cache(@nospecialize(interp::ReactantInterpreter)) = interp.local_cache +CC.cache_owner(@nospecialize(interp::ReactantInterpreter)) = interp.token +function CC.method_table(@nospecialize(interp::ReactantInterpreter)) + return CC.OverlayMethodTable(CC.get_inference_world(interp), REACTANT_METHOD_TABLE) +end + +function has_ancestor(query::Module, target::Module) + query == target && return true + while true + next = parentmodule(query) + next == target && return true + next == query && return false + query = next + end +end +is_base_or_core(t::TypeVar) = begin + println("TypeVar ", t) + return false +end +is_base_or_core(t::Core.TypeofVararg) = is_base_or_core(t.T) +is_base_or_core(m::Module) = has_ancestor(m, Core) || has_ancestor(m, Base) +is_base_or_core(@nospecialize(u::Union)) = begin + u == Union{} && return true + is_base_or_core(u.a) && is_base_or_core(u.b) +end +is_base_or_core(u::UnionAll) = is_base_or_core(Base.unwrap_unionall(u)) +is_base_or_core(@nospecialize(ty::Type)) = is_base_or_core(parentmodule(ty)) + +function skip_rewrite(mi::Core.MethodInstance)::Bool + mod = mi.def.module + mi.def.file in __skip_files && return true + @lk mi + ft = Base.unwrap_unionall(mi.specTypes).parameters[1] + ft in __skip_rewrite_func_set && return true + + ( + has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) + ) && return true + + if is_base_or_core(mod) + modules = is_base_or_core.(Base.unwrap_unionall(mi.specTypes).parameters[2:end]) + all(modules) && return true + end + return false +end + +include("auto_cf/AutoCF.jl") + +disable_call_with_reactant = false +vv = [] +vb = [] +@inline function typeinf_local(interp::CC.AbstractInterpreter, frame::CC.InferenceState) + @invoke CC.typeinf_local(interp::CC.AbstractInterpreter, frame) +end + +function CC.typeinf_local(interp::ReactantInterpreter, frame::CC.InferenceState) + global tree + mi = frame.linfo + global disable_call_with_reactant + disable_cwr = disable_call_with_reactant ? 
false : skip_rewrite(mi) + disable_cwr && (disable_call_with_reactant = true) + #@warn disable_cwr mi + disable_call_with_reactant || push!(vb, (mi, CC.copy(frame.src))) + tl = typeinf_local(interp, frame) + disable_call_with_reactant || push!(vv, (mi, CC.copy(frame.src))) + disable_cwr && (disable_call_with_reactant = false) + return tl +end + +function CC.optimize( + interp::ReactantInterpreter, opt::CC.OptimizationState, caller::CC.InferenceResult +) + tree = get(interp.meta_data[].traced_tree_map, mi_key(opt.linfo), nothing) + CC.@timeit "optimizer" ir = if isnothing(tree) || isempty(tree) + CC.run_passes_ipo_safe(opt.src, opt, caller) + else + run_passes_ipo_safe_auto_cf(opt.src, opt, caller, tree) + end + CC.ipo_dataflow_analysis!(interp, ir, caller) + return CC.finish(interp, opt, ir, caller) +end + +function run_passes_ipo_safe_auto_cf( + ci::CC.CodeInfo, + sv::CC.OptimizationState, + caller::CC.InferenceResult, + tree::Tree, + optimize_until=nothing, # run all passes by default +) + __stage__ = 0 # used by @pass + # NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work + CC.@pass "convert" ir = CC.convert_to_ircode(ci, sv) + CC.@pass "slot2reg" ir = CC.slot2reg(ir, ci, sv) + + analysis_reassign_block_id!(tree, ir, ci) + # TODO: Domsorting can produce an updated domtree - no need to recompute here + CC.@pass "compact 1" ir = CC.compact!(ir) + ir = control_flow_transform!(tree, ir) + CC.@pass "Inlining" ir = CC.ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds) + # @timeit "verify 2" verify_ir(ir) + CC.@pass "compact 2" ir = CC.compact!(ir) + CC.@pass "SROA" ir = CC.sroa_pass!(ir, sv.inlining) + @info sv.linfo + CC.@pass "ADCE" (ir, made_changes) = CC.adce_pass!(ir, sv.inlining) + if made_changes + CC.@pass "compact 3" ir = CC.compact!(ir, true) + end + if CC.is_asserts() + CC.@timeit "verify 3" begin + CC.verify_ir(ir, true, false, CC.optimizer_lattice(sv.inlining.interp)) + CC.verify_linetable(ir.linetable) + end + end + CC.@label __done__ # used by @pass + return ir +end + +lead_to_dynamic_call(@nospecialize(ty)) = begin + isconcretetype(ty) && return false + ty == Union{} && return false + Base.isvarargtype(ty) && return true + (ty <: Type || ty <: Tuple) && return false + return true +end + +# Rewrite type unstable calls to recurse into call_with_reactant to ensure +# they continue to use our interpreter. 
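+# Illustrative case (hypothetical): with `f = rand(Bool) ? sin : cos`, the
+# site `f(x)` has no statically known callee, so inference alone cannot keep
+# the callee inside this interpreter; wrapping the site as
+# `call_with_reactant(f, x)` makes the recursion go through our pipeline.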
+function need_rewrite_call(interp, @nospecialize(fn), @nospecialize(args))
+    # A UnionAll constructor has no singleton type and is not handled by the call_with_reactant machinery: rewriting it would degrade type inference
+    isnothing(fn) && return false
+    # ignore constructors
+    fn isa Type && return false
+
+    ft = typeof(fn)
+    (ft <: Core.IntrinsicFunction || ft <: Core.Builtin) && return false
+    ft in __skip_rewrite_func_set && return false
+    #Base.isstructtype(ft) && return false
+    if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module)
+        mod = ft.name.module
+        # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions
+        if has_ancestor(mod, Reactant.Ops) ||
+            has_ancestor(mod, Reactant.TracedUtils) ||
+            has_ancestor(mod, Reactant.MLIR) ||
+            has_ancestor(mod, Core.Compiler)
+            return false
+        end
+    end
+    #ft isa Type && any(t -> ft <: t, __skip_rewrite_type_constructor_list) && return false
+    #ft in __skip_rewrite_func_set && return false
+
+    #ft <: typeof(Core.kwcall) && return true
+    tt = Tuple{ft,args...}
+    match = CC._findsup(tt, REACTANT_METHOD_TABLE, CC.get_inference_world(interp))[1]
+    !isnothing(match) && return true
+    match = CC._findsup(tt, nothing, CC.get_inference_world(interp))[1]
+    isnothing(match) && return true
+    startswith(
+        string(match.method.name), "#(overlay (. Reactant (inert REACTANT_METHOD_TABLE))"
+    ) && return false
+
+    # Avoid recursively interpreting into methods we define explicitly
+    # as overloads, which we assume should handle the entirety of the
+    # translation (and if not they can use call_in_reactant).
+    isdefined(match.method, :external_mt) &&
+        match.method.external_mt === REACTANT_METHOD_TABLE &&
+        return false
+
+    match.method.file in __skip_files && return false
+
+    # Dynamic dispatch handling: rewrite if any argument type in the signature can lead to a dynamic call
+    types = if match.method.nospecialize != 0
+        match.method.sig
+    else
+        mi = CC.specialize_method(match)
+        mi.specTypes
+    end
+
+    mask = lead_to_dynamic_call.(Base.unwrap_unionall(types).parameters)
+    #@error string(ft) mask types
+    return any(mask)
+end
+
+function CC.abstract_eval_call(
+    interp::ReactantInterpreter,
+    e::Expr,
+    vtypes::Union{CC.VarTable,Nothing},
+    sv::CC.AbsIntState,
+)
+    if !(sv isa CC.IRInterpretationState) # during type inference, rewrite dynamic calls through call_with_reactant
+        global disable_call_with_reactant
+        if !disable_call_with_reactant
+            argtypes = CC.collect_argtypes(interp, e.args, vtypes, sv)
+            args = CC.argtypes_to_type(argtypes).parameters
+            fn = CC.singleton_type(argtypes[1])
+            if need_rewrite_call(interp, fn, args[2:end])
+                @error fn string(argtypes) sv.linfo
+                log(interp) && push!(
+                    interp.debug_data[].rewrite_call,
+                    (fn, args[2:end], sv.linfo), #CC.copy(sv.src)
+                )
+                e = Expr(:call, GlobalRef(@__MODULE__, :call_with_reactant), e.args...)
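+                # Persist the rewrite into the enclosing CodeInfo so later
+                # passes see `call_with_reactant(f, args...)`; the statement is
+                # either a bare `:call` or a slot assignment `slot = f(...)`.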
+ expr = sv.src.code[sv.currpc] + sv.src.code[sv.currpc] = if expr.head == :call + e + else + @assert expr.head == :(=) #CodeInfo slot write + Expr(:(=), expr.args[1], e) + end + end + else + log(interp) && push!( + interp.debug_data[].non_rewrite_call, + (sv.linfo, CC.collect_argtypes(interp, e.args, vtypes, sv)), + ) + end + end + + return @invoke CC.abstract_eval_call( + interp::CC.AbstractInterpreter, + e::Expr, + vtypes::Union{CC.VarTable,Nothing}, + sv::CC.AbsIntState, + ) +end + +using LLVM, LLVM.Interop + +struct CompilerInstance + lljit::LLVM.JuliaOJIT + lctm::LLVM.LazyCallThroughManager + ism::LLVM.IndirectStubsManager +end +const jit = Ref{CompilerInstance}() + +function get_trampoline(job) + (; lljit, lctm, ism) = jit[] + jd = JITDylib(lljit) + + target_sym = String(gensym(string(job.source))) + + # symbol flags (callable + exported) + flags = LLVM.API.LLVMJITSymbolFlags( + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) + + sym = Ref(LLVM.API.LLVMOrcCSymbolFlagsMapPair(mangle(lljit, target_sym), flags)) + + # materialize callback: compile/emit module when symbols requested + function materialize(mr) + JuliaContext() do ctx + ir, meta = GPUCompiler.compile(:llvm, job; validate=false) + runtime_log(GPUCompiler.get_interpreter(job)) && @warn "materialize" job + @lk ir + # Ensure the module's entry has the target name we declared + LLVM.name!(meta.entry, target_sym) + r_symbols = string.(LLVM.get_requested_symbols(mr)) + #expose only the function defined in job + for f in LLVM.functions(ir) + isempty(LLVM.blocks(f)) && continue #declare functions + LLVM.name(f) in r_symbols && continue + LLVM.linkage!(f, LLVM.API.LLVMPrivateLinkage) + end + + #convert global alias to private linkage in order to not be relocatable + for g in LLVM.globals(ir) + ua = LLVM.API.LLVMGetUnnamedAddress(g) + (ua == LLVM.API.LLVMLocalUnnamedAddr || ua == LLVM.API.LLVMNoUnnamedAddr) || + continue + LLVM.isconstant(g) && continue + LLVM.API.LLVMSetUnnamedAddress(g, LLVM.API.LLVMNoUnnamedAddr) + LLVM.linkage!(g, LLVM.API.LLVMPrivateLinkage) + end + # serialize the module IR into a memory buffer + buf = convert(MemoryBuffer, ir) + # deserialize under a thread-safe context and emit via IRCompileLayer + ThreadSafeContext() do ts_ctx + tsm = context!(context(ts_ctx)) do + mod = parse(LLVM.Module, buf) + ThreadSafeModule(mod) + end + + il = LLVM.IRCompileLayer(lljit) + # Emit the ThreadSafeModule for the responsibility mr. + LLVM.emit(il, mr, tsm) + end + end + return nothing + end + + # discard callback (no-op for now) + function discard(jd_arg, sym) + @error "discard" sym + end + + # Create a single CustomMaterializationUnit that declares both entry and target. + # Name it something descriptive (e.g., the entry_sym) + mu = LLVM.CustomMaterializationUnit("MU_" * target_sym, sym, materialize, discard) + + # Define the MU in the JITDylib (declares the symbols as owned by this MU) + LLVM.define(jd, mu) + + # Lookup the entry address (this will trigger materialize if needed) + addr = lookup(lljit, target_sym) + return addr +end +import GPUCompiler: deferred_codegen_jobs + +function ccall_deferred(ptr::Ptr{Cvoid}) + return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +""" + Reactant.REDUB_ARGUMENTS_NAME + +The variable name bound to `call_with_reactant`'s tuple of arguments in its +`@generated` method definition. 
+ +This binding can be used to manually reference/destructure `call_with_reactants` arguments + +This is required because user arguments could have a name which clashes with whatever name we choose for +our argument. Thus we gensym to create it. + +This originates from + https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 + https://github.com/JuliaGPU/GPUCompiler.jl/blob/master/examples/jit.jl +""" +const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") + +function deferred_call_with_reactant( + world::UInt, source::LineNumberNode, self, @nospecialize(args) +) + f = args[1] + tt = Tuple{f,args[2:end]...} + match = CC._findsup(tt, REACTANT_METHOD_TABLE, world) + match = isnothing(match[1]) ? CC._findsup(tt, nothing, world) : match + + stub = Core.GeneratedFunctionStub( + identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() + ) + + if isnothing(match[1]) + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) + return stub(world, source, method_error) + end + + mi = CC.specialize_method(match[1]) + + target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=false) + config = CompilerConfig( + target, + CompilerParams(); + kernel=false, + libraries=false, + toplevel=true, + validate=false, + strip=false, + optimize=true, + entry_abi=:func, + ) + job = CompilerJob(mi, config, world) + interp = GPUCompiler.get_interpreter(job) + + ci = CC.typeinf_ext(interp, mi) + @assert !isnothing(ci) + rt = ci.rettype + @lk ci job + runtime_log(interp) && @warn "ci rt" job ci rt + + addr = get_trampoline(job) + trampoline = pointer(addr) + id = Base.reinterpret(Int, trampoline) + + deferred_codegen_jobs[id] = job + + #build CodeInfo directly + code_info = begin + ir = CC.IRCode() + src = @ccall jl_new_code_info_uninit()::Ref{CC.CodeInfo} + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = UInt64 + CC.ir_to_codeinf!(src, ir) + end + + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] + function push_inst!(inst) + push!(overdubbed_code, inst) + push!(overdubbed_codelocs, code_info.codelocs[1]) + return Core.SSAValue(length(overdubbed_code)) + end + code_info.edges = Core.MethodInstance[job.source] + code_info.rettype = rt + + ptr = push_inst!(Expr(:call, :ccall_deferred, trampoline)) + + fn_args = [] + for i in 2:length(args) + named_tuple_ssa = Expr( + :call, Core.GlobalRef(Core, :getfield), Core.SlotNumber(2), i + ) + arg = push_inst!(named_tuple_ssa) + push!(fn_args, arg) + end + + f_arg = push_inst!(Expr(:call, Core.GlobalRef(Core, :getfield), Core.SlotNumber(2), 1)) + + args_vec = push_inst!( + Expr(:call, GlobalRef(Base, :getindex), GlobalRef(Base, :Any), fn_args...) 
+ ) + + runtime_log(interp) && push_inst!( + Expr( + :call, + GlobalRef(Base, :println), + "before call_with_reactant ", + f_arg, + "(", + args_vec, + ")", + ), + ) + preserve = push_inst!(Expr(:gc_preserve_begin, args_vec)) + args_vec = push_inst!(Expr(:call, GlobalRef(Base, :pointer), args_vec)) + n_args = length(fn_args) + + #Use ccall internal directly to call the wrapped llvm function + result = push_inst!( + Expr( + :foreigncall, + ptr, + Ptr{rt}, + Core.svec(Any, Ptr{Any}, Int), + 0, + QuoteNode(:ccall), + f_arg, + args_vec, + n_args, + n_args, + args_vec, + f_arg, + ), + ) + + result = push_inst!(Expr(:call, GlobalRef(Base, :unsafe_pointer_to_objref), result)) + push_inst!(Expr(:gc_preserve_end, preserve)) + result = push_inst!(Expr(:call, GlobalRef(@__MODULE__, :barrier), result, rt)) + runtime_log(interp) && push_inst!( + Expr( + :call, + GlobalRef(Base, :println), + "after call_with_reactant ", + f_arg, + " ", + result, + ), + ) + push_inst!(Core.ReturnNode(result)) + + code_info.min_world = typemin(UInt) + code_info.max_world = typemax(UInt) + code_info.slotnames = Any[:call_with_reactant_, REDUB_ARGUMENTS_NAME] + code_info.slotflags = UInt8[0x00, 0x00] + code_info.code = overdubbed_code + code_info.codelocs = overdubbed_codelocs + code_info.ssavaluetypes = length(overdubbed_code) + code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] + return code_info +end + +@eval function call_with_reactant($(REDUB_ARGUMENTS_NAME)...) + $(Expr(:meta, :generated_only)) + return $(Expr(:meta, :generated, deferred_call_with_reactant)) +end +const jd_main = Ref{Any}() +function init_jit() + lljit = JuliaOJIT() + jd_main[] = JITDylib(lljit) + prefix = LLVM.get_prefix(lljit) + + dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix) + add!(jd_main[], dg) + + es = ExecutionSession(lljit) + + lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es) + ism = LLVM.LocalIndirectStubsManager(triple(lljit)) + + jit[] = CompilerInstance(lljit, lctm, ism) + atexit() do + (; lljit, lctm, ism) = jit[] + dispose(ism) + dispose(lctm) + dispose(lljit) + end +end + +function ir_to_codeinfo!(ir::CC.IRCode)::CC.CodeInfo + code_info = begin + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = Int + CC.ir_to_codeinf!(src, ir) + src.ssavaluetypes = length(src.ssavaluetypes) + src + end + return code_info +end + +struct FakeOc + f::Vector + ci::Vector{CC.CodeInfo} +end + +fake_oc_dict = FakeOc([], []) + +fake_oc(ir::CC.IRCode, return_type) = begin + src = ir_to_codeinfo!(ir) + fake_oc(src, return_type) +end + +function fake_oc(src::CC.CodeInfo, return_type; args=nothing) + @assert !isnothing(current_interpreter[]) + types = isnothing(args) ? src.slottypes[2:end] : args + global fake_oc_dict + index = findfirst(==(src), fake_oc_dict.ci) + !isnothing(index) && return fake_oc_dict.f[index] + + expr = (Expr(:(::), Symbol("arg_$i"), type) for (i, type) in enumerate(types)) + args = Expr(:tuple, (Symbol("arg_$i") for (i, type) in enumerate(types))...) + fn_name = gensym(:fake_oc) + call_expr = Expr(:call, fn_name, expr...) 
+ f_expr = Expr( + :(=), + call_expr, + quote + Reactant.barrier($args, $return_type) + end, + ) + f = @eval @noinline $f_expr + mi = Base.method_instance(f, types) + @assert !isnothing(mi) + mi.def.source = CC.maybe_compress_codeinfo(current_interpreter[], mi, src) + push!(fake_oc_dict.f, f) + push!(fake_oc_dict.ci, src) + return f +end diff --git a/src/Reactant.jl b/src/Reactant.jl index e27c38b2ed..d8c247e1db 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -329,7 +329,7 @@ function __init__() ) end end - + init_jit() return nothing end diff --git a/src/auto_cf/AutoCF.jl b/src/auto_cf/AutoCF.jl new file mode 100644 index 0000000000..a5d94de857 --- /dev/null +++ b/src/auto_cf/AutoCF.jl @@ -0,0 +1,6 @@ +include("debug_utils.jl") +include("new_inference.jl") +include("code_info_mut.jl") +include("code_ir_utils.jl") +include("mlir_utils.jl") +include("code_gen.jl") \ No newline at end of file diff --git a/src/auto_cf/analysis.jl b/src/auto_cf/analysis.jl new file mode 100644 index 0000000000..5b79530441 --- /dev/null +++ b/src/auto_cf/analysis.jl @@ -0,0 +1,55 @@ +@enum UpgradeSlot NoUpgrade UpgradeLocally UpgradeDefinition UpgradeDefinitionGlobal + +@enum State Traced Upgraded Maybe NotTraced + +mutable struct ForStructure + accus::Tuple + header_bb::Int + latch_bb::Int + terminal_bb::Int + body_bbs::Set{Int} + state::State +end + +struct IfStructure + ssa_cond + header_bb::Int + terminal_bb::Int + true_bbs::Set{Int} + false_bbs::Set{Int} + owned_true_bbs::Set{Int} + owned_false_bbs::Set{Int} + legalize::Ref{Bool} #inform that the if traced GotoIfNot can pass type inference +end + +mutable struct SlotAnalysis + slot_stmt_def::Vector{Integer} #0 for argument + slot_bb_usage::Vector{Set{Int}} +end + + +CFStructure = Union{IfStructure,ForStructure} +mutable struct Tree + node::Union{Nothing,Base.uniontypes(CFStructure)...} + children::Vector{Tree} + parent::Ref{Tree} +end + +Base.isempty(tree::Tree) = isnothing(tree.node) && length(tree.children) == 0 + +Base.show(io::IO, t::Tree) = begin + Base.print(io, '(') + Base.show(io, t.node) + Base.print(io, ',') + Base.show(io, t.children) + Base.print(io, ')') +end + +mutable struct Analysis + tree::Tree + domtree::Union{Nothing,Vector{CC.DomTreeNode}} + postdomtree::Union{Nothing,Vector{CC.DomTreeNode}} + slotanalysis::Union{Nothing,SlotAnalysis} + pending_tree::Union{Nothing,Tree} +end + diff --git a/src/auto_cf/code_gen.jl b/src/auto_cf/code_gen.jl new file mode 100644 index 0000000000..bb9dbe3e1a --- /dev/null +++ b/src/auto_cf/code_gen.jl @@ -0,0 +1,708 @@ +#TODO: remove this +returning_type(X) = X +get_traced_type(X) = X + + + +struct TUnitRange{T} + min::Union{T,Reactant.TracedRNumber{T}} + max::Union{T,Reactant.TracedRNumber{T}} +end + +struct TStepRange{T} + min::Union{T,Reactant.TracedRNumber{T}} + step::T #TODO:add support to traced step + max::Union{T,Reactant.TracedRNumber{T}} +end + +#Needed otherwise standard lib defined a more specialized method +(::Colon)(min::Reactant.TracedRNumber{T}, max::Reactant.TracedRNumber{T}) where {T} = TUnitRange(min, max) + +(::Colon)(min::Union{T,Reactant.TracedRNumber{T}}, max::Union{T,Reactant.TracedRNumber{T}}) where {T} = TUnitRange(min, max) +Base.first(a::TUnitRange) = a.min +Base.last(a::TUnitRange) = a.max +@noinline Base.iterate(i::TUnitRange{T}, _::Nothing=nothing) where {T} = CC.inferencebarrier(i)::Union{Nothing,Tuple{Reactant.TracedRNumber{T},Nothing}} + +(::Colon)(min::Union{T,Reactant.TracedRNumber{T}}, step::T, max::Union{T,Reactant.TracedRNumber{T}}) where {T} = TStepRange(min, 
step, max)
+Base.first(a::TStepRange) = a.min
+Base.last(a::TStepRange) = a.max
+@noinline Base.iterate(i::TStepRange{T}, _::Nothing=nothing) where {T} = CC.inferencebarrier(i)::Union{Nothing,Tuple{Reactant.TracedRNumber{T},Nothing}}
+
+
+# keep using the Base iterate protocol for upgraded loops.
+Base.iterate(T::Type, args...) = CC.inferencebarrier(args)::T
+
+
+
+"""
+    Hidden
+
+    Wrapper that hides the default printing of a value; used to show a CodeIR containing inlined CodeIR.
+    #TODO: add parametric
+"""
+struct Hidden
+    value
+end
+
+function Base.show(io::IO, x::Hidden)
+    print(io, "<$(typeof(x.value))>")
+end
+
+
+"""
+    juliair_to_mlir(ir::Core.Compiler.IRCode, args...) -> (mask, results)
+
+Execute `ir` and add an MLIR `return` operation for the traced return values of `ir`. Returns a boolean mask marking which results are traced, together with all of `ir`'s return values.
+TODO: remove masked_traced
+`args` must satisfy the type restrictions in `ir.argtypes`; otherwise this will crash Julia outright.
+"""
+function juliair_to_mlir(ir::Core.Compiler.IRCode, args...)::Tuple{Tuple{Vararg{Bool}},Tuple}
+    (ir3, args3) = (ir, args)
+    @lk ir3 args3
+    @warn typeof.(args)
+    @warn ir.argtypes[2:end]
+    # Cannot use .<: -> it dispatches to Reactant `materialize`
+    equal = length(args) == length(ir.argtypes[2:end])
+    for (a, b) in zip(typeof.(args), ir.argtypes[2:end])
+        equal || break
+        equal = a <: b
+    end
+    @assert equal "$(typeof.(args)) \n $(ir.argtypes[2:end])"
+    f = Core.OpaqueClosure(ir)
+    result = f(args...)
+    isnothing(result) && return ((), ())
+    result = result isa Tuple ? result : tuple(result)
+    masked_traced = isa.(result, Union{Reactant.RArray,Reactant.RNumber})
+    (masked_traced, result)
+end
+
+
+function remove_phi_node_for_body!(ir::CC.IRCode, f::ForStructure)
+    first_bb = min(f.body_bbs...)
+    traced_ssa = []
+    type_traced_ssa = Type[]
+    for index in ir.cfg.blocks[first_bb].stmts
+        stmt = ir.stmts.stmt[index]
+        isnothing(stmt) && continue # phi nodes can be simplified during IR compaction
+        stmt isa Core.PhiNode || break
+        ir.stmts.stmt[index] = stmt.values[1]
+        type = ir.stmts.type[index]
+        is_traced(type) || continue
+        push!(traced_ssa, stmt.values[1])
+        push!(type_traced_ssa, type)
+    end
+    traced_ssa, type_traced_ssa
+end
+
+"""
+    apply_transformation!(ir::Core.Compiler.IRCode, if_::IfStructure)
+    Apply static Julia IR changes to `ir` in order to trace the `if` described by `if_`.
+    Creates a call to `jit_if_controlflow`, which at runtime traces the two branches from the two extracted IRCode objects.
+ TODO: add mutable support like for +""" +function apply_transformation!(ir::Core.Compiler.IRCode, if_::IfStructure) + (; header_bb::Int, terminal_bb::Int, true_bbs::Set{Int}, false_bbs::Set{Int}, owned_true_bbs::Set{Int}, owned_false_bbs::Set{Int}) = if_ + true_phi_ssa = [] + false_phi_ssa = [] + if_returned_types = Type[] + phi_index = [] + #In the last block of if, collect all phi_values + for index in ir.cfg.blocks[terminal_bb].stmts + ir.stmts.stmt[index] isa Core.PhiNode || break + push!(phi_index, index) + phi = ir.stmts.stmt[index] + phi_type::Type = ir.stmts.type[index] + if_returned_type::Union{Type,Nothing} = returning_type(phi_type) #TODO: deal with promotion here + if_returned_type isa Nothing && error("transformation failed") + push!(if_returned_types, if_returned_type) + add_phi_value!(true_phi_ssa, phi, true_bbs) + add_phi_value!(false_phi_ssa, phi, false_bbs) + end + + #map the old argument with the new ones + new_args_dict = Dict() + r1 = nothing + if !isempty(true_bbs) + r1 = extract_multiple_block_ir(ir, true_bbs, new_args_dict, true_phi_ssa) + clear_block_ir!(ir, owned_true_bbs) + end + + r2 = nothing + if !isempty(false_bbs) + r2 = extract_multiple_block_ir(ir, false_bbs, new_args_dict, false_phi_ssa) + clear_block_ir!(ir, owned_false_bbs) + end + + #common arguments for both branch + (value, new_args_v) = vec_args(ir, new_args_dict) + + r1 = isnothing(r1) ? nothing : finish(r1, new_args_v) + r2 = isnothing(r2) ? nothing : finish(r2, new_args_v) + + #remove MethodInstance name (needed for OpaqueClosure) + new_args_v = new_args_v[2:end] + + + cond = cond_ssa(ir, header_bb) + owned_bbs = union(owned_true_bbs, owned_false_bbs) + #Mutate IR + #replace GotoIfNot -> GotoNode + #TODO: can cond be defined before goto? + ssa_goto = terminator_index(ir, header_bb) + change_stmt!(ir, terminator_index(ir, max(owned_bbs...)), Core.GotoNode(terminal_bb), Any) + change_stmt!(ir, ssa_goto, nothing, Nothing) + change_stmt!(ir, ssa_goto - 1, nothing, Nothing) + + #PhiNodes simplifications: + n_result = length(phi_index) + + + new_phi = [] + for phi_i in phi_index + removed_index = [] + phi = ir.stmts.stmt[phi_i] + for (i, edge) in enumerate(phi.edges) + edge in owned_bbs || continue + push!(removed_index, i) + end + push!(new_phi, length(phi.edges) - length(removed_index) == 0 ? nothing : removed_index) + end + + + sign = (Reactant.TracedRNumber{Bool}, Hidden, Hidden, Int, new_args_v...) + @error sign + w = Reactant.current_interpreter[].world + @error w + mi = method_instance(jit_if_controlflow, sign, w) + isnothing(mi) && error("invalid Method Instance") + + @lk r1 r2 + @assert(!isnothing(cond)) + all_args = (cond, + Hidden(r1), Hidden(r2), + length(true_phi_ssa), + value...) + @lk all_args sign + expr = Expr( + :invoke, + mi, + GlobalRef(@__MODULE__, :jit_if_controlflow), + all_args... 
+ ) + + #all phi nodes are replaced and return one result: special case: the if can be created in the final block + if all((==).(new_phi, nothing)) && n_result == 1 + change_stmt!(ir, first(phi_index), expr, get_traced_type(returning_type(ir.stmts.type[first(phi_index)]))) + @goto out + end + + if_ssa = if n_result == 1 + ni = Core.Compiler.NewInstruction(expr, get_traced_type(returning_type(only(if_returned_types)))) + if_ssa = Core.Compiler.insert_node!(ir, ssa_goto, ni, false) + else + tuple = Core.Compiler.NewInstruction(expr, Tuple{if_returned_types...}) + Core.Compiler.insert_node!(ir, ssa_goto, tuple, false) + end + + + for (i, removed_index_phi) in enumerate(new_phi) + if isnothing(removed_index_phi) + ir.stmts.stmt[phi_index[i]] = if n_result == 1 + if_ssa + else + Expr(:call, Core.GlobalRef(Base, :getindex), if_ssa, i) + end + else + current_phi = ir.stmts.stmt[phi_index[i]] + isempty(removed_index_phi) && continue + deleteat!(current_phi.edges, removed_index_phi) + deleteat!(current_phi.values, removed_index_phi) + push!(current_phi.edges, header_bb) + #modify phi branch: in the case of several result, get result_i in if definition block + if n_result == 1 + push!(current_phi.values, if_ssa) + else + expr = Expr(:call, Core.GlobalRef(Base, :getindex), if_ssa, i) + ni = Core.Compiler.NewInstruction(expr, Tuple{if_returned_types...}) + result_i = Core.Compiler.insert_node!(ir, ssa_goto, ni, false) + push!(current_phi.values, result_i) + end + end + end + @label out + ir +end + +#TODO: remove +traced_promotion(X) = X + + +runtime_inner_type(e::Union{Reactant.RArray,Reactant.RNumber}) = Reactant.MLIR.IR.type(e.mlir_data) +runtime_inner_type(e) = typeof(e) + +Base.getindex(::Tuple{}, ::Tuple{}) = () + +""" + jit_if_controlflow(cond::Reactant.TracedRNumber{Bool}, true_b::Core.Compiler.IRCode, false_b::Core.Compiler.IRCode, args...) -> Type + +During runtime, create an if MLIR operation from two branches `true_b` `false_b` Julia IRCode using the arguments `args`. +Return either a traced value or a tuple of traced values. + +""" +function jit_if_controlflow(cond::Reactant.TracedRNumber{Bool}, r1, r2, n_result, args...) + tmp_if_op = Reactant.MLIR.Dialects.stablehlo.if_( + cond.mlir_data; true_branch=Reactant.MLIR.IR.Region(), false_branch=Reactant.MLIR.IR.Region(), result_0=[Reactant.MLIR.IR.Type(Nothing)] + ) + + b1 = Reactant.MLIR.IR.Block() + push!(Reactant.MLIR.IR.region(tmp_if_op, 1), b1) + Reactant.MLIR.IR.activate!(b1) + local_args_r1 = deepcopy.(args) + before_r1 = get_mlir_pointer_or_nothing.(local_args_r1) + tr1 = !isnothing(r1.value) ? juliair_to_mlir(r1.value, local_args_r1...)[2] : () + tr1 = upgrade.(tr1) + after_r1 = get_mlir_pointer_or_nothing.(local_args_r1) + masked_muted_r1 = before_r1 .!== after_r1 + Reactant.MLIR.IR.deactivate!(b1) + + + b2 = Reactant.MLIR.IR.Block() + push!(Reactant.MLIR.IR.region(tmp_if_op, 2), b2) + + Reactant.MLIR.IR.activate!(b2) + local_args_r2 = deepcopy.(args) + before_r2 = get_mlir_pointer_or_nothing.(local_args_r2) + tr2 = !isnothing(r2.value) ? 
juliair_to_mlir(r2.value, local_args_r2...)[2] : ()
+    tr2 = upgrade.(tr2)
+    after_r2 = get_mlir_pointer_or_nothing.(local_args_r2)
+    masked_muted_r2 = before_r2 .!== after_r2
+    Reactant.MLIR.IR.deactivate!(b2)
+
+
+    t1 = typeof.(tr1)
+    t2 = typeof.(tr2)
+
+    # Assume result types are equal for now: TODO: can probably be relaxed by promoting types (needs changes to `juliair_to_mlir` and the static IRCode analysis)
+    @assert t1 == t2 "each branch $t1 $t2 must have the same type"
+
+    #TODO: select special case
+
+    @lk before_r1 before_r2
+    @lk args local_args_r1 local_args_r2
+    both_mut = (&).(masked_muted_r1, masked_muted_r2) |> collect
+    masked_unique_muted_r1 = (&).(masked_muted_r1, (!).(both_mut)) |> collect
+    masked_unique_muted_r2 = (&).(masked_muted_r2, (!).(both_mut)) |> collect
+    @lk both_mut masked_unique_muted_r1 masked_unique_muted_r2 masked_muted_r1 masked_muted_r2
+    tr1_muted = (local_args_r1[masked_unique_muted_r1]..., upgrade.(args[masked_unique_muted_r2])...)
+    tr2_muted = (upgrade.(args[masked_unique_muted_r1])..., local_args_r2[masked_unique_muted_r2]...)
+    @lk tr1_muted tr2_muted
+
+    Reactant.MLIR.IR.activate!(b1)
+    #TODO: promotion here
+    Reactant.Ops.return_(tr1..., local_args_r1[both_mut]..., tr1_muted...)
+    Reactant.MLIR.IR.deactivate!(b1)
+
+    Reactant.MLIR.IR.activate!(b2)
+    #TODO: promotion here
+    Reactant.Ops.return_(tr2..., local_args_r2[both_mut]..., tr2_muted...)
+    Reactant.MLIR.IR.deactivate!(b2)
+
+    return_types = Reactant.MLIR.IR.type.(getfield.(tr1, :mlir_data))
+    mut_types = Reactant.MLIR.IR.type.(getfield.(local_args_r1[both_mut], :mlir_data))
+    mut_types2 = Reactant.MLIR.IR.type.(getfield.(tr1_muted, :mlir_data))
+
+    @warn return_types
+    if_op = Reactant.MLIR.Dialects.stablehlo.if_(
+        cond.mlir_data;
+        true_branch=Reactant.MLIR.IR.Region(),
+        false_branch=Reactant.MLIR.IR.Region(),
+        result_0=Reactant.MLIR.IR.Type[return_types..., mut_types..., mut_types2...]
+    )
+    Reactant.MLIR.API.mlirRegionTakeBody(
+        Reactant.MLIR.IR.region(if_op, 1), Reactant.MLIR.IR.region(tmp_if_op, 1)
+    )
+    Reactant.MLIR.API.mlirRegionTakeBody(
+        Reactant.MLIR.IR.region(if_op, 2), Reactant.MLIR.IR.region(tmp_if_op, 2)
+    )
+
+    results = Vector(undef, length(t1))
+    for (i, e) in enumerate(tr1)
+        traced = deepcopy(e)
+        traced.mlir_data = Reactant.MLIR.IR.result(if_op, i) #TODO: setmlirdata
+        results[i] = traced
+    end
+
+    @lk if_op
+
+    arg_offset = length(t1)
+    for (i, index) in enumerate(findall((|).(masked_muted_r1, masked_muted_r2)))
+        Reactant.TracedUtils.set_mlir_data!(args[index], Reactant.MLIR.IR.result(if_op, arg_offset + i))
+    end
+
+    Reactant.MLIR.API.mlirOperationDestroy(tmp_if_op.operation)
+
+    #TODO: add a runtime type check here using static analysis
+    return length(results) == 1 ?
only(results) : Tuple(results) +end + + +#remove iterator usage in JuliaIR and keep branch +function remove_iterator(ir::CC.IRCode, bb::Int) + @lk ir bb + terminator_pos = terminator_index(ir.cfg, bb) + cond = ir.stmts.stmt[terminator_pos].cond + cond isa Core.SSAValue || return + iterator_index = cond.id - 2 + iterator_expr = ir.stmts.stmt[iterator_index] + @assert iterator_expr isa Expr && iterator_expr.head == :call && iterator_expr.args[1] == GlobalRef(Base, :iterate) + + iterator_def = iterator_expr.args[end] + for i in iterator_index:iterator_index+2 + change_stmt!(ir, i, nothing, Nothing) + end + iterator_def +end + +function list_phi_nodes_values(ir::CC.IRCode, in_bb::Int32, phi_bb::Int32) + r = [] + for index in ir.cfg.blocks[in_bb].stmts + stmt = ir.stmts.stmt[index] + isnothing(stmt) && continue #phi node can be simplified during IR compact + stmt isa Core.PhiNode || break + index_phi = findfirst(x -> x == phi_bb, stmt.edges) + isnothing(index_phi) && continue + push!(r, stmt.values[index_phi]) + end + r +end + + +function apply_transformation!(ir::CC.IRCode, f::ForStructure) + f.state == Maybe && return + @lk ir f + body_phi_ssa = list_phi_nodes_values(ir, Int32(min(f.body_bbs...)), Int32(f.header_bb)) + terminal_phi_ssa = list_phi_nodes_values(ir, Int32(f.terminal_bb), Int32(f.header_bb)) + #check terminal block Phi nodes and find the incumulators by doing the substraction between terminal body and first body block phi nodes + accumulars_mask = Vector() + for ssa in terminal_phi_ssa + push!(accumulars_mask, ssa in body_phi_ssa) + end + + new_args_dict = Dict() + #TODO: rewrite this: to use terminal_phi_ssa directly + (traced_ssa_for_bodies, traced_ssa_for_bodies_types) = remove_phi_node_for_body!(ir, f) + + ir_back = CC.copy(ir) + @lk ir_back + #iteration to reenter loop + remove_iterator(ir, max(f.body_bbs...)) + + last_bb = max(f.body_bbs...) + results = [] + for index in ir.cfg.blocks[f.terminal_bb].stmts + stmt = ir.stmts.stmt[index] + stmt isa Core.PhiNode || break + for (e_index, bb) in enumerate(stmt.edges) + bb == last_bb || continue + push!(results, stmt.values[e_index]) + end + end + body_bbs = f.body_bbs + @lk ir body_bbs new_args_dict results traced_ssa_for_bodies traced_ssa_for_bodies_types + #TODO: replace result with terminal_phi_ssa + loop_body = extract_multiple_block_ir(ir, f.body_bbs, new_args_dict, results).ir + @info "body" loop_body + #value doesn't contain the function name unlike new_args_v + (value, new_args_v) = vec_args(ir, new_args_dict) + iterator_index = 0 + for (i, t) in enumerate(new_args_v) + (t isa Union && Nothing in Base.uniontypes(t)) || continue + iterator_index = i - 1 + break + end + #iteration to enter the loop + iterator_def = remove_iterator(ir, f.header_bb) + #fix cfg + + change_stmt!(ir, terminator_index(ir, f.header_bb), Core.GotoNode(f.terminal_bb), Any) + change_stmt!(ir, terminator_index(ir, last_bb), Core.GotoNode(f.terminal_bb), Any) + clear_block_ir!(ir, f.body_bbs) + + t = if iterator_def isa QuoteNode #constant iterator: + #IMPORTANT: object must be copied: QuoteNode.value cannot be reused in Opaque Closure + iterator_def = copy(iterator_def.value) + typeof(iterator_def) + else + ir.stmts.type[iterator_def.id] + end + + @lk value new_args_v terminal_phi_ssa t + while_output_type = (typeof_ir(ir, ssa) for ssa in terminal_phi_ssa) + #first element in new_args_v/ value is the iterator first step: only the iterator definition is needed + sign = (t, Hidden, Int, Vector{Bool}, while_output_type..., new_args_v[2:end]...) 
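+    # Runtime call layout: iterator definition, hidden body IRCode, index of
+    # the iterator argument (0 if the body never uses it), accumulator mask,
+    # one entry per loop-carried output, then the captured outer arguments.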
+ @lk sign + mi = method_instance(jit_loop_controlflow, CC.widenconst.(sign), current_interpreter[].world) + isnothing(mi) && error("invalid Method Instance") + expr = Expr( + :invoke, + mi, + GlobalRef(Main, :jit_loop_controlflow), + iterator_def, + Hidden(loop_body), + iterator_index, + accumulars_mask, + terminal_phi_ssa..., + value..., + ) + @warn expr + phi_index = [] + #In the last block of for, collect all phi_values + for index in ir.cfg.blocks[f.terminal_bb].stmts + ir.stmts.stmt[index] isa Core.PhiNode || break + push!(phi_index, index) + end + + if length(phi_index) == 0 + CC.insert_node!(ir, CC.SSAValue(start_index(ir, f.terminal_bb)), Core.Compiler.NewInstruction(expr, Any), false) + elseif length(phi_index) == 1 + phi = only(phi_index) + change_stmt!(ir, phi, expr, returning_type(ir.stmts.type[phi])) + else + while_ssa = Core.SSAValue(terminator_index(ir, f.header_bb) - 1) + change_stmt!(ir, while_ssa.id, expr, Tuple{while_output_type...}) + for (i, index) in enumerate(phi_index) + ir.stmts.stmt[index] = Expr(:call, Core.GlobalRef(Base, :getindex), while_ssa, i) + end + + end +end + +function get_mlir_pointer_or_nothing(x::Union{Reactant.TracedRNumber,Reactant.TracedRArray}) + Reactant.TracedUtils.get_mlir_data(x).value +end + +get_mlir_pointer_or_nothing(_) = nothing + +#iterator for_body iterator_type n_init traced_ssa_for_bodies args +function jit_loop_controlflow(iterator, for_body::Hidden, iterator_index::Int, accu_mask::Vector{Bool}, args_full...) + #only support UnitRange atm + (start, stop, iterator_begin, iter_step) = if iterator isa Union{Base.OneTo,UnitRange,TUnitRange,StepRange,TStepRange} + start = first(iterator) + stop = last(iterator) + iter_step = iterator isa Union{StepRange,TStepRange} ? iterator.step : 1 + (start, stop, Reactant.Ops.constant(start), iter_step) + else + error("unsupported type $(typeof(iterator))") + end + + start = first(iterator) + stop = last(iterator) + @lk start + iterator_ = is_traced(typeof(start)) ? start : Reactant.TracedRNumber{typeof(start)}((), nothing) + n_accu = length(accu_mask) + @lk n_accu args_full accu_mask iterator_index + accus = args_full[1:n_accu] + julia_use_iter = iterator_index != 0 + args = args_full[(n_accu+1):end] + @lk args accus + tmp_while_op = Reactant.MLIR.Dialects.stablehlo.while_( + Reactant.MLIR.IR.Value[]; + cond=Reactant.MLIR.IR.Region(), + body=Reactant.MLIR.IR.Region(), result_0=Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type.(accus)...] + ) + + mlir_loop_args = Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type(iterator_), Reactant.Ops.mlir_type.(accus)...] + cond = Reactant.MLIR.IR.Block(mlir_loop_args, [Reactant.MLIR.IR.Location() + for _ in mlir_loop_args]) + push!(Reactant.MLIR.IR.region(tmp_while_op, 1), cond) + + @lk cond mlir_loop_args + + Reactant.MLIR.IR.activate!(cond) + Reactant.Ops.activate_constant_context!(cond) + t1 = deepcopy(iterator_) + Reactant.TracedUtils.set_mlir_data!(t1, Reactant.MLIR.IR.argument(cond, 1)) + r = iter_step > 0 ? 
t1 < stop : t1 > stop + Reactant.Ops.return_(r) + Reactant.Ops.deactivate_constant_context!(cond) + Reactant.MLIR.IR.deactivate!(cond) + + body = Reactant.MLIR.IR.Block(mlir_loop_args, [Reactant.MLIR.IR.Location() + for _ in mlir_loop_args]) + push!(Reactant.MLIR.IR.region(tmp_while_op, 2), body) + + for (i, arg) in enumerate(accus) + arg_ = deepcopy(arg) + Reactant.TracedUtils.set_mlir_data!(arg_, Reactant.MLIR.IR.argument(body, i + 1)) + end + + #TODO: add try finally + Reactant.MLIR.IR.activate!(body) + Reactant.Ops.activate_constant_context!(body) + iter_reactant = deepcopy(iterator_) + Reactant.TracedUtils.set_mlir_data!(iter_reactant, Reactant.MLIR.IR.argument(body, 1)) + + @lk iter_reactant args for_body + + block_accus = [] + for j in eachindex(args) + if args[j] isa Union{Reactant.TracedRNumber,Reactant.TracedRArray} + for k in eachindex(accus) + (isnothing(args[j]) || isnothing(accus[k])) && continue + args[j].mlir_data == accus[k].mlir_data || continue + tmp = Reactant.TracedUtils.set_mlir_data!(deepcopy(args[j]), Reactant.MLIR.IR.argument(body, 1 + k)) + push!(block_accus, tmp) + @goto break2 + end + end + push!(block_accus, args[j]) + @label break2 + end + + pointer_before = get_mlir_pointer_or_nothing.(args) + + if iterator_index != 0 + block_accus[iterator_index] = (iter_reactant, nothing) + end + + @lk block_accus + + t = juliair_to_mlir(for_body.value, block_accus...)[2] + + #we use a local defined variable inside of for outside: the argument must be added to while operation (cond and body) + + pointer_after = get_mlir_pointer_or_nothing.(args) + + muted_mask = collect(pointer_before .!= pointer_after) + args_muted = args[muted_mask] + + for (am, old_value) in zip(args_muted, pointer_before[muted_mask]) + type = Reactant.MLIR.IR.type(am.mlir_data) + Reactant.MLIR.IR.push_argument!(cond, type) + new_value = Reactant.MLIR.IR.push_argument!(body, type) + @warn "changed $(Reactant.MLIR.IR.Value(old_value)) to $new_value" + @lk new_value + change_value!(Reactant.MLIR.IR.Value(old_value), new_value, body) + end + + @lk pointer_before pointer_after t body args_muted accus + + iter_next = iter_step > 0 ? iter_reactant + iter_step : iter_reactant - abs(iter_step) + Reactant.Ops.return_(iter_next, t..., args_muted...) + + Reactant.MLIR.IR.deactivate!(body) + Reactant.Ops.deactivate_constant_context!(body) + + @lk iterator_begin + + while_op = Reactant.MLIR.Dialects.stablehlo.while_( + Reactant.MLIR.IR.Value[Reactant.TracedUtils.get_mlir_data(iterator_begin), Reactant.TracedUtils.get_mlir_data.(accus)..., Reactant.MLIR.IR.Value.(pointer_before[muted_mask])...]; + cond=Reactant.MLIR.IR.Region(), + body=Reactant.MLIR.IR.Region(), result_0=Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type(iterator_begin), Reactant.Ops.mlir_type.(accus)..., Reactant.Ops.mlir_type.(args_muted)...] + ) + + Reactant.MLIR.API.mlirRegionTakeBody( + Reactant.MLIR.IR.region(while_op, 1), Reactant.MLIR.IR.region(tmp_while_op, 1) + ) + Reactant.MLIR.API.mlirRegionTakeBody( + Reactant.MLIR.IR.region(while_op, 2), Reactant.MLIR.IR.region(tmp_while_op, 2) + ) + + init_mlir_result_offset = max(1, julia_use_iter ? 1 : 0) #TODO: suspicions probably min + n = init_mlir_result_offset + length(accus) + for (i, muted) in enumerate(args_muted) + Reactant.TracedUtils.set_mlir_data!(muted, Reactant.MLIR.IR.result(while_op, n + i)) + end + + Reactant.MLIR.API.mlirOperationDestroy(tmp_while_op.operation) + + results = [] + for (i, accu) in enumerate(accus) + r_i = deepcopy(accu) #TODO: is this needed? 
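+        # Rebind this accumulator to the matching while-op result so code
+        # after the loop observes the loop-carried value.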
+ Reactant.TracedUtils.set_mlir_data!(r_i, Reactant.MLIR.IR.result(while_op, i + init_mlir_result_offset)) + @info "r_i" r_i + push!(results, r_i) + end + + #loop can contain non accus which are returned + # x = 5 + # for i in 1:10 + # x = 2 + # end + # x + + return length(results) == 1 ? only(results) : Tuple(results) +end + + +function post_order(tree::Tree) + v = [] + for c in tree.children + push!(v, post_order(c)...) + end + push!(v, tree.node) +end + +""" + control_flow_transform!(an::Analysis, ir::Core.Compiler.IRCode) -> Core.Compiler.IRCode + apply changes to traced control flow, `ir` argument is not valid anymore +""" +function control_flow_transform!(tree::Tree, ir::CC.IRCode)::CC.IRCode + for node in post_order(tree)[1:end-1] + apply_transformation!(ir, node) + ir = CC.compact!(ir, false) + end + CC.compact!(ir, true) +end + + +#= + analysis_reassign_block_id!(an::Analysis, ir::Core.IRCode, src::Core.CodeInfo) + slot2reg can change type infered CodeInfo CFG by removing non-reachable block, + ControlFlow analysis use blocks information and must be shifted + +=# +function analysis_reassign_block_id!(tree::Tree, ir::CC.IRCode, src::CC.CodeInfo) + isempty(tree) && return false + cfg = CC.compute_basic_blocks(src.code) + length(ir.cfg.blocks) == length(cfg.blocks) && return false + @info "rewrite analysis blocks" + new_block_map = [] + i = 0 + for block in cfg.blocks + unreacheable_block = all(x->src.ssavaluetypes[x] === Union{}, block.stmts) + i = unreacheable_block ? i : i + 1 + push!(new_block_map, i) + end + @info new_block_map + function reassign_tree!(s::Set{Int}) + n = [new_block_map[i] for i in s] + empty!(s) + push!(s, n...) + end + + function reassign_tree!(is::IfStructure) + is.header_bb = new_block_map[is.header_bb] + is.terminal_bb = new_block_map[is.terminal_bb] + reassign_tree!(is.true_bbs) + reassign_tree!(is.false_bbs) + reassign_tree!(is.owned_true_bbs) + reassign_tree!(is.owned_false_bbs) + end + + function reassign_tree!(fs::ForStructure) + fs.header_bb = new_block_map[fs.header_bb] + fs.latch_bb = new_block_map[fs.latch_bb] + fs.terminal_bb = new_block_map[fs.terminal_bb] + reassign_tree!(fs.body_bbs) + end + + function reassign_tree!(t::Tree) + isnothing(t.node) || reassign_tree!(t.node) + for c in t.children + reassign_tree!(c) + end + end + reassign_tree!(tree) + return true +end \ No newline at end of file diff --git a/src/auto_cf/code_info_mut.jl b/src/auto_cf/code_info_mut.jl new file mode 100644 index 0000000000..0eab493b14 --- /dev/null +++ b/src/auto_cf/code_info_mut.jl @@ -0,0 +1,106 @@ +struct ShiftedSSA + e::Int +end + +struct ShiftedCF + e::Int +end + +function offset_stmt!(stmt, index, next_bb = true) + if stmt isa Expr + Expr( + stmt.head, (offset_stmt!(a, index) for a in stmt.args)...) + elseif stmt isa Core.ReturnNode + Core.ReturnNode(offset_stmt!(stmt.val, index)) + elseif stmt isa Core.SSAValue + Core.SSAValue(offset_stmt!(ShiftedSSA(stmt.id), index)) + elseif stmt isa Core.GotoIfNot + Core.GotoIfNot(offset_stmt!(stmt.cond, index), offset_stmt!(ShiftedCF(stmt.dest), index, next_bb)) + elseif stmt isa Core.GotoNode + Core.GotoNode(offset_stmt!(ShiftedCF(stmt.label), index, next_bb)) + elseif stmt isa ShiftedSSA + stmt.e + (stmt.e < index ? 0 : 1) + elseif stmt isa ShiftedCF + stmt.e + (stmt.e < index + (next_bb ? 1 : 0) ? 
0 : 1) + else + stmt + end +end + +#insert stmt in frame after index +function add_instruction!(frame, index, stmt; type=CC.NotFound(), next_bb = true) + add_instruction!(frame.src, index, stmt; type, next_bb) + frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: more fine graine change here + insert!(frame.stmt_info, index + 1, CC.NoCallInfo()) + insert!(frame.stmt_edges, index + 1, nothing) + insert!(frame.handler_at, index + 1, (0,0)) + frame.cfg = CC.compute_basic_blocks(frame.src.code) + Core.SSAValue(index + 1) +end + +function modify_instruction!(frame, index, stmt) + frame.src.code[index] = stmt + frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: refine this +end + +""" + add_instruction!(ir::CC.CodeInfo, index, stmt; next_bb) + +""" +function add_instruction!(ir::CC.CodeInfo, index, stmt; type=CC.NotFound(), next_bb=true) + for (i, c) in enumerate(ir.code) + ir.code[i] = offset_stmt!(c, index + 1, next_bb) + end + insert!(ir.code, index + 1, stmt) + insert!(ir.codelocs, index + 1, 0) + insert!(ir.ssaflags, index + 1, 0x00000000) + if ir.ssavaluetypes isa Int + ir.ssavaluetypes = ir.ssavaluetypes + 1 + else + insert!(ir.ssavaluetypes, index + 1, type) + end +end + + +function create_slot!(ir::CC.CodeInfo)::Core.SlotNumber + push!(ir.slotflags, 0x00) + push!(ir.slotnames, Symbol("")) + Core.SlotNumber(length(ir.slotflags)) +end + +function create_slot!(frame)::Core.SlotNumber + push!(frame.slottypes, Union{}) + for s in frame.bb_vartables + isnothing(s) && continue + push!(s, CC.VarState(Union{}, true)) + end + create_slot!(frame.src) +end + +add_slot_change!(ir::CC.CodeInfo, index, old_slot::Int) = add_slot_change!(ir, index, Core.SlotNumber(old_slot)) + +function add_slot_change!(ir::CC.CodeInfo, index, old_slot::Core.SlotNumber) + push!(ir.slotflags, 0x00) + push!(ir.slotnames, Symbol("")) + new_slot = Core.SlotNumber(length(ir.slotflags)) + add_instruction!(frame, index, Expr(:(=), new_slot, Expr(:call, GlobalRef(@__MODULE__, :upgrade), old_slot))) + update_ir_new_slot(ir, index, old_slot, new_slot) +end + +function update_ir_new_slot(ir, index, old_slot, new_slot) + for i in index+2:length(ir.code) #TODO: probably need to refine this + ir.code[i] = replace_slot_stmt(ir.code[i], old_slot, new_slot) + end +end + +function replace_slot_stmt(stmt, old_slot, new_slot) + if stmt isa Core.NewvarNode + stmt + elseif stmt isa Expr + Expr(stmt.head, (replace_slot_stmt(e, old_slot, new_slot) for e in stmt.args)...) + elseif stmt isa Core.SlotNumber + stmt == old_slot ? 
new_slot : stmt
+    else
+        stmt
+    end
+end
\ No newline at end of file
diff --git a/src/auto_cf/code_ir_utils.jl b/src/auto_cf/code_ir_utils.jl
new file mode 100644
index 0000000000..2c1a24a688
--- /dev/null
+++ b/src/auto_cf/code_ir_utils.jl
@@ -0,0 +1,324 @@
+"""
+    method_instance(f::Function, sign::Tuple{Vararg{Type}}, world) -> Union{Base.MethodInstance, Nothing}
+
+Same as `Base.method_instance`, except that it also works inside generated functions such as `call_with_reactant`.
+"""
+function method_instance(f::Function, sign::Tuple{Vararg{Type}}, world)
+    tt = Base.signature_type(f, sign)
+    match, _ = Core.Compiler._findsup(tt, nothing, world)
+    isnothing(match) && return nothing
+    mi = Core.Compiler.specialize_method(match)
+    return mi
+end
+
+"""
+    change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, stmt, return_type::Type) -> Core.Compiler.Instruction
+
+Replace the statement of `ir` at position `ssa` with `stmt`, giving it type `return_type`.
+TODO: when stmt is the terminator: Goto -> nothing : must update cfg
+"""
+function change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, stmt, return_type::Type)
+    return Core.Compiler.inst_from_newinst!(
+        ir[Core.SSAValue(ssa)], Core.Compiler.NewInstruction(stmt, return_type), Int32(0), UInt32(0)
+    )
+end
+
+
+"""
+    change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, goto::Core.GotoNode, return_type::Type) -> Core.Compiler.Instruction
+Specialization of [`change_stmt!`](@ref) for `Core.GotoNode` to deal with control-flow-graph changes.
+"""
+function change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, goto::Core.GotoNode, return_type::Type)
+    bb::Int64 = Core.Compiler.block_for_inst(ir, ssa)
+    succs = ir.cfg.blocks[bb].succs
+    empty!(succs)
+    push!(succs, goto.label)
+    push!(ir.cfg.blocks[goto.label].preds, bb)
+    @invoke change_stmt!(ir, ssa, goto::Any, return_type)
+end
+
+"""
+    clear_block_ir!(ir::Core.Compiler.IRCode, blocks::Set{Int})
+    Replace every instruction of the basic blocks `blocks` in `ir` with `nothing`.
+"""
+function clear_block_ir!(ir::Core.Compiler.IRCode, blocks::Set{Int})
+    for block in blocks
+        stmt_range::Core.Compiler.StmtRange = ir.cfg.blocks[block].stmts
+        (f, l) = (stmt_range |> first, stmt_range |> last)
+        for i in f:l
+            change_stmt!(ir, i, nothing, Nothing)
+        end
+    end
+end
+
+
+"""
+    type_from_ssa(ir::Core.Compiler.IRCode, args::Dict, v::Vector)::Vector
+    For each value in `v`, get its type from `ir` (or from `args` for promoted arguments).
+"""
+function type_from_ssa(ir::Core.Compiler.IRCode, args::Dict, v)
+    [
+        begin
+            if e isa Core.SSAValue
+                ir.stmts.type[e.id]
+            elseif e isa Core.Argument
+                args[e][2]
+            else
+                typeof(e)
+            end
+
+        end
+        for e in v
+    ]
+end
+
+"""
+    apply_map(array, block_map::Dict)::Vector
+    For each element of `array` present in the dictionary `block_map`, get the associated value.
+"""
+function apply_map(array, block_map)
+    [block_map[a] for a in array if haskey(block_map, a)]
+end
+
+
+"""
+    new_cfg(ir, to_extract::Vector, block_map)::Core.Compiler.CFG
+    Get the new CFG of `ir` after the extraction of the `to_extract` blocks.
+"""
+function new_cfg(ir, to_extract, block_map)
+    n = 1
+    bbs = Core.Compiler.BasicBlock[]
+    index = Int64[]
+    for b in to_extract
+        bb = ir.cfg.blocks[b]
+        (; start, stop) = bb.stmts
+        diff = stop - start
+        push!(bbs, Core.Compiler.BasicBlock(Core.Compiler.StmtRange(n, diff + n),
+            apply_map(bb.preds, block_map),
+            apply_map(bb.succs, block_map)))
+        n += diff + 1
+        push!(index, n)
+    end
+    Core.Compiler.CFG(bbs, index)
+end
+
+"""
+    WipExtracting
+    Wrapper for an extracted IRCode that is not fully constructed yet.
+"""
+struct WipExtracting
    ir::Core.Compiler.IRCode
+end
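+# Hedged sketch of how these helpers compose (the `ir` value, block ids and
+# return SSA are hypothetical; `finish` is assumed to turn a `WipExtracting`
+# into a standalone IRCode, see the call sites in code_gen.jl):
+#
+#     args = Dict()                       # SSA values promoted to arguments
+#     wip  = extract_multiple_block_ir(ir, Set([2, 3]), args, Any[ret_ssa])
+#     body = finish(wip, vec_args(ir, args)[2])
+#     f    = Core.OpaqueClosure(body)     # later executed by juliair_to_mlir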
+""" + is_a_terminator(stmt) + Check if `stmt` is a terminator +""" +function is_a_terminator(stmt) + stmt isa Union{Core.GotoNode,Core.ReturnNode,Core.GotoIfNot} +end + + +""" + offset_stmt!(dict::Dict, stmt, offset::Int, ir::Core.Compiler.IRCode, bb_map) + internal recursive function of [`extract_multiple_block_ir`](@ref) to shift SSAValue/Argument/BasicBlock in `ir` +""" +function offset_stmt!(dict::Dict, stmt, offset::Dict, ir::Core.Compiler.IRCode, bb_map) + if stmt isa Expr + Expr( + stmt.head, (offset_stmt!(dict, a, offset, ir, bb_map) for a in stmt.args)...) + elseif stmt isa Core.Argument + tmp = Core.Argument(length(dict) + 2) + get!(dict, stmt, (tmp, ir.argtypes[stmt.n]))[1] + elseif stmt isa Core.ReturnNode + Core.ReturnNode(offset_stmt!(dict, stmt.val, offset, ir, bb_map)) + elseif stmt isa Core.SSAValue + stmt_bb = Core.Compiler.block_for_inst(ir, stmt.id) + if stmt_bb in keys(offset) #TODO: remove? && stmt.id > offset[stmt_bb] + Core.SSAValue(stmt.id - offset[stmt_bb]) + else + #the stmt is transformed to an IR argument + tmp = Core.Argument(length(dict) + 2) + get!(dict, stmt, (tmp, ir.stmts.type[stmt.id]))[1] + end + elseif stmt isa Core.GotoNode + Core.GotoNode(get(bb_map, stmt.label, 0)) + elseif stmt isa Core.GotoIfNot + Core.GotoIfNot(offset_stmt!(dict, stmt.cond, offset, ir, bb_map), get(bb_map, stmt.dest, 0)) + elseif stmt isa Core.PhiNode + Core.PhiNode(Int32[bb_map[edge] for edge in stmt.edges], Any[offset_stmt!(dict, value, offset, ir, bb_map) for value in stmt.values]) + elseif stmt isa Core.PiNode + Core.PiNode(offset_stmt!(dict, stmt.val, offset, ir, bb_map), stmt.typ) + else + stmt + end +end + +""" + extract_multiple_block_ir(ir, to_extract_set::Set, args::Dict, new_returns::Vector)::WipExtracting + Extract from `ir` a list of blocks `to_extract_set`, creating an new independant IR containing only these blocks. + All unlinked SSA are added to the `args` dictionnary and all values of `new_returns` are returned by the new IR. +""" +function extract_multiple_block_ir(ir::Core.Compiler.IRCode, to_extract_set::Set{Int}, args::Dict, new_returns::Vector)::WipExtracting + @assert isempty(ir.new_nodes.stmts) + to_extract = sort(collect(to_extract_set)) + #for each extracted basic block, get the new offset. 
+"""
+    extract_multiple_block_ir(ir, to_extract_set::Set, args::Dict, new_returns::Vector)::WipExtracting
+
+Extract the blocks `to_extract_set` from `ir`, creating a new independent IR that
+contains only these blocks. Every SSA value not defined inside those blocks is
+added to the `args` dictionary, and all values of `new_returns` are returned by the new IR.
+"""
+function extract_multiple_block_ir(ir::Core.Compiler.IRCode, to_extract_set::Set{Int}, args::Dict, new_returns::Vector)::WipExtracting
+    @assert isempty(ir.new_nodes.stmts)
+    to_extract = sort(collect(to_extract_set))
+    #for each extracted basic block, compute its new offset;
+    #needed for non-contiguous extraction, where offsets no longer follow the `ir` block layout
+    bb_offset::Dict{Int,Int} = Dict()
+    cumulative_offset = (ir.cfg.blocks[first(to_extract)].stmts |> first) - 1
+    new_n_stmt = 0
+    for bb in minimum(to_extract):maximum(to_extract)
+        n_stmt = ir.cfg.blocks[bb].stmts |> length
+        if bb in to_extract
+            bb_offset[bb] = cumulative_offset
+            new_n_stmt += n_stmt
+        else
+            cumulative_offset += n_stmt
+        end
+    end
+
+    block_map = Dict()
+    for (i, b) in enumerate(to_extract)
+        block_map[b] = i
+    end
+
+    cfg = new_cfg(ir, to_extract, block_map)
+
+    f = ir.cfg.blocks[first(to_extract)].stmts |> first
+    l = ir.cfg.blocks[last(to_extract)].stmts |> last
+
+    #PhiNodes reference the global IR: either shift them or add them to the new IR arguments
+    for (i, rb) in enumerate(new_returns)
+        rb isa Union{Core.SSAValue,Core.Argument} || continue
+        new_returns[i] = offset_stmt!(args, rb, bb_offset, ir, block_map)
+    end
+
+    #recreate the instruction stream of the extracted blocks
+    instruction_stream = Core.Compiler.InstructionStream(new_n_stmt)
+    dico = Dict()
+    new_stmt = 0
+    for bb in to_extract
+        range_bb = ir.cfg.blocks[bb].stmts[[1, end]]
+        for old_stmt in range_bb[1]:range_bb[2]
+            new_stmt += 1
+            Core.Compiler.setindex!(instruction_stream, ir.stmts[old_stmt], new_stmt) #TODO: check if needed
+            #ssa offset
+            instruction_stream.stmt[new_stmt] = offset_stmt!(args, ir.stmts.stmt[old_stmt], bb_offset, ir, block_map)
+            #line info
+            line_info = ir.stmts.line[old_stmt]
+            line_info == 0 && continue
+            instruction_stream.line[new_stmt] = get!(dico, line_info, length(dico) + 1)
+        end
+    end
+
+    linetable = ir.linetable[sort(collect(keys(dico)))]
+    linetable = [Core.LineInfoNode(l.module, l.method, l.file, l.line, Int32(0)) for l in linetable]
+    #build the new IR argtypes from the `args` dictionary
+    (_, argtypes) = vec_args(ir, args)
+    new_ir = Core.Compiler.IRCode(instruction_stream, cfg, linetable, argtypes, Expr[], Core.Compiler.VarState[])
+
+    #a Julia IR block can end without a terminator
+    has_terminator = is_a_terminator(instruction_stream.stmt[end])
+
+    n_ssa = length(instruction_stream)
+    retu = if length(new_returns) > 1
+        tuple = Core.Compiler.NewInstruction(Expr(:call, Core.GlobalRef(Core, :tuple), new_returns...), Tuple{type_from_ssa(new_ir, args, new_returns)...})
+        Core.Compiler.insert_node!(new_ir, Core.Compiler.SSAValue(n_ssa), tuple, !has_terminator)
+    else
+        length(new_returns) == 1 ? only(new_returns) : nothing
+    end
+
+    if has_terminator
+        change_stmt!(new_ir, n_ssa, Core.ReturnNode(retu), Nothing)
+    else
+        terminator = Core.Compiler.NewInstruction(Core.ReturnNode(retu), Nothing)
+        Core.Compiler.insert_node!(new_ir, Core.Compiler.SSAValue(n_ssa), terminator, true)
+    end
+    WipExtracting(Core.Compiler.compact!(new_ir, true))
+end
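+
+# Hypothetical end-to-end sketch of the extraction API (the `_example_` name and
+# the choice of blocks 2 and 3 are illustrative assumptions): pull two blocks of
+# `ir` out into a standalone IRCode, then materialize the argument list that was
+# collected along the way.
+function _example_extract(ir::Core.Compiler.IRCode)
+    args = Dict()   # SSA values promoted to arguments of the extracted IR
+    wip = extract_multiple_block_ir(ir, Set([2, 3]), args, Any[])
+    (_, argtypes) = vec_args(ir, args)
+    return finish(wip, argtypes)
+end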
+
+function mlir_type(x)
+    return Reactant.MLIR.IR.TensorType(size(x), Reactant.MLIR.IR.Type(Reactant.unwrapped_eltype(x)))
+end
+
+"""
+    vec_args(ir::Core.Compiler.IRCode, new_args::Dict)
+
+Construct the `(values, argtypes)` vectors from the `new_args` dictionary.
+"""
+function vec_args(ir::Core.Compiler.IRCode, new_args::Dict)
+    argtypes = Vector(undef, length(new_args) + 1)
+    argtypes[1] = Core.Const("opaque")
+    value = Vector(undef, length(new_args))
+    for (arg, index) in new_args
+        value[index[1].n - 1] = arg
+        argtypes[index[1].n] = if arg isa Core.Argument #TODO: reuse function
+            index[2]
+        else
+            ir.stmts.type[arg.id]
+        end
+    end
+    (value, argtypes)
+end
+
+"""
+    typeof_ir(ir::CC.IRCode, e::Union{Core.Argument, Core.SSAValue})
+
+Return the type of a statement in `ir`.
+TODO: replace by `CC.argextype`
+"""
+function typeof_ir(ir::CC.IRCode, e::Union{Core.Argument, Core.SSAValue})
+    if e isa Core.Argument
+        ir.argtypes[e.n]
+    else
+        ir.stmts.type[e.id]
+    end
+end
+
+"""
+    finish(wir::WipExtracting, new_args::Vector)::Core.Compiler.IRCode
+
+Finish constructing the extracted IR by applying the full argument list.
+"""
+function finish(wir::WipExtracting, new_args::Vector)
+    (; ir) = wir
+    empty!(ir.argtypes)
+    append!(ir.argtypes, new_args)
+    ir
+end
+
+"""
+    add_phi_value!(v::Vector, phi::Core.PhiNode, edge::Set{Int})
+
+Push onto `v` the `Core.PhiNode` values whose edges belong to the set `edge`.
+"""
+function add_phi_value!(v::Vector, phi::Core.PhiNode, edge::Set{Int})
+    for (i, e) in enumerate(phi.edges)
+        e in edge || continue
+        push!(v, phi.values[i])
+    end
+end
+
+"""
+    cond_ssa(ir::CC.IRCode, bb::Int)
+
+Return the SSA value of the condition of a traced `GotoIfNot`.
+"""
+function cond_ssa(ir::CC.IRCode, bb::Int)
+    ti = terminator_index(ir, bb)
+    terminator = ir.stmts.stmt[ti]
+    terminator isa Core.GotoIfNot || return
+    protection = ir.stmts.stmt[terminator.cond.id]
+    (protection isa Expr && protection.head == :call && protection.args[1] == Core.GlobalRef(@__MODULE__, :traced_protection)) || return
+    protection.args[2]
+end
+
+"""
+    check_integrity(ir::CC.IRCode)::Bool
+
+Check that no `unreachable` is present in the IR; return `true` if there is none.
+"""
+function check_integrity(ir::CC.IRCode)::Bool
+    !any(ir.stmts.stmt .== [Core.ReturnNode()])
+end
\ No newline at end of file
diff --git a/src/auto_cf/debug_utils.jl b/src/auto_cf/debug_utils.jl
new file mode 100644
index 0000000000..d8ad1863a4
--- /dev/null
+++ b/src/auto_cf/debug_utils.jl
@@ -0,0 +1,26 @@
+macro stop(n::Int)
+    u = :counter #gensym()
+    e = esc(u)
+    quote
+        isdefined(@__MODULE__, $(QuoteNode(u))) || global $e = $n
+        global $e
+        $e < 2 && error("stop")
+        $e -= 1
+    end
+end
+
+#leak each argument to a global variable and store every instance of it
+macro lks(args...)
+    nargs = [Symbol(string(arg) * "s") for arg in args]
+    quote
+        $([:(
+            let val = $(esc(p))
+                isdefined(@__MODULE__, $(QuoteNode(n))) || global $(esc(n)) = []
+                global $(esc(n))
+                push!($(esc(n)), val)
+            end
+        ) for (p, n) in zip(args, nargs)]...)
+    end
+end
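+
+# Illustrative sketch of the debug macros (the `_example_` helper is
+# hypothetical): `@lks x` appends every value that flows through this point to
+# a global vector `xs`, for post-mortem inspection in the REPL; `@lk x` (defined
+# in JIT.jl) leaks only the last value to a global `x`.
+function _example_leak(x)
+    @lks x
+    return x
+end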
\ No newline at end of file
diff --git a/src/auto_cf/mlir_utils.jl b/src/auto_cf/mlir_utils.jl
new file mode 100644
index 0000000000..059dc5f315
--- /dev/null
+++ b/src/auto_cf/mlir_utils.jl
@@ -0,0 +1,23 @@
+function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, op::Reactant.MLIR.IR.Operation)
+    for i in 1:Reactant.MLIR.IR.noperands(op)
+        Reactant.MLIR.IR.operand(op, i) == from || continue
+        Reactant.MLIR.IR.operand!(op, i, to)
+    end
+
+    for i in 1:Reactant.MLIR.IR.nregions(op)
+        r = Reactant.MLIR.IR.region(op, i)
+        change_value!(from, to, r)
+    end
+end
+
+function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, region::Reactant.MLIR.IR.Region)
+    for block in Reactant.MLIR.IR.BlockIterator(region)
+        change_value!(from, to, block)
+    end
+end
+
+function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, block::Reactant.MLIR.IR.Block)
+    for op in Reactant.MLIR.IR.OperationIterator(block)
+        change_value!(from, to, op)
+    end
+end
\ No newline at end of file
diff --git a/src/auto_cf/new_inference.jl b/src/auto_cf/new_inference.jl
new file mode 100644
index 0000000000..3d6c192530
--- /dev/null
+++ b/src/auto_cf/new_inference.jl
@@ -0,0 +1,920 @@
+#simplified version of `CC.scan_slot_def_use`
+function fill_slot_definition_map(frame)
+    n_slot = length(frame.src.slotnames)
+    n_args = length(frame.linfo.specTypes.types)
+    v = [0 for _ in 1:n_slot]
+    for (i, stmt) in enumerate(frame.src.code)
+        stmt isa Expr || continue
+        stmt.head == :(=) || continue
+        slot = stmt.args[1]
+        slot isa Core.SlotNumber || continue
+        slot.id > n_args || continue
+        v[slot.id] = v[slot.id] == 0 ? i : v[slot.id]
+    end
+    return v
+end
+
+function fill_slot_usage_map(frame)
+    n_slot = length(frame.src.slotnames)
+    v = [Set() for _ in 1:n_slot]
+    for (pos, stmt) in enumerate(frame.src.code)
+        get_slot(v, stmt, frame, pos)
+    end
+    return v
+end
+
+function get_slot(vec, stmt, frame, pos)
+    if stmt isa Expr
+        stmt.head == :(=) && return get_slot(vec, stmt.args[2], frame, pos)
+        for e in stmt.args
+            get_slot(vec, e, frame, pos)
+        end
+    elseif stmt isa Core.SlotNumber
+        push!(vec[stmt.id], CC.block_for_inst(frame.cfg, pos))
+    else
+        stmt
+    end
+end
+
+
+#an = Analysis(Tree(nothing, [], Ref{Tree}()), nothing, nothing, nothing, nothing)
+
+function update_tree!(an::Analysis, bb::Int)
+    for c in an.tree.children
+        c.node.header_bb == bb || continue
+        an.pending_tree = c
+        return true
+    end
+    return false
+end
+
+function add_tree!(an::Analysis, tl)
+    parent = an.tree
+    t = Tree(tl, [], Ref{Tree}(parent))
+    push!(parent.children, t)
+    return an.pending_tree = t
+end
+
+#several traced control-flow structures can end in the same bb
+function up_tree!(an::Analysis, bb)
+    terminal = false
+    while is_terminal_bb(an.tree, bb)
+        an.tree = an.tree.parent[]
+        terminal = true
+    end
+    terminal && return nothing
+    #the terminal bb is not always reached: for instance, when bodies are inferred more precisely and nothing changes
+    while !isnothing(an.tree.node)
+        in_header(bb, an.tree.node) && break
+        an.tree = an.tree.parent[]
+    end
+end
+
+function down_tree!(an::Analysis, bb)
+    for child in an.tree.children
+        if child.node.header_bb == bb
+            an.tree = child
+            break
+        end
+    end
+end
+
+function is_terminal_bb(tree::Tree, bb)
+    isnothing(tree.node) && return false
+    return tree.node.terminal_bb == bb
+end
+
+Base.in(bb::Int, is::IfStructure) = bb in is.true_bbs || bb in is.false_bbs
+Base.in(bb::Int, is::ForStructure) = bb in is.body_bbs
+
+function in_header(bb::Int, is::IfStructure)
+    return bb in 
is.true_bbs || bb in is.false_bbs || bb == is.header_bb +end +function in_header(bb::Int, is::ForStructure) + return bb in is.body_bbs || bb == is.header_bb || bb == is.latch_bb +end + +function in_stack(tree::Tree, bb::Int) + while !isnothing(tree.node) + in_header(bb, tree.node) && return true + tree = tree.parent[] + end + return false +end + +#TODO: don't recompute TCF each time +function add_cf!(an, frame, currbb, currpc, condt) + update_tree!(an, frame.currbb) && return false + + tl = is_a_traced_loop(an, frame.src, frame.cfg, frame.currbb) + if tl !== nothing + add_tree!(an, tl) + return false + end + + tl = is_a_traced_if(an, frame, frame.currbb, condt) + if tl !== nothing + add_tree!(an, tl) + !tl.legalize[] || return false + #legalize if by inserting a call + goto_if_not_index = terminator_index(frame.cfg, frame.currbb) + cond = tl.ssa_cond + ssa = add_instruction!( + frame, + goto_if_not_index - 1, + Expr(:call, GlobalRef(@__MODULE__, :traced_protection), cond), + ) + invalidate_slot_definition_analysis!(an) + (; dest::Int) = frame.src.code[goto_if_not_index + 1]::Core.GotoIfNot #shifted because of the insertion + modify_instruction!(frame, goto_if_not_index + 1, Core.GotoIfNot(ssa, dest)) + tl.legalize[] = true + return true + end + return false +end + +@noinline traced_protection(x::Reactant.TracedRNumber{Bool}) = CC.inferencebarrier(x)::Bool +Reactant.@skip_rewrite_func traced_protection + +@noinline upgrade(x) = Reactant.Ops.constant(x) +@noinline upgrade(x::Union{Reactant.TracedRNumber,Reactant.TracedRArray}) = x + +Reactant.@skip_rewrite_func upgrade +#TODO: need a new traced mode Julia Type Non-concrete -> Traced +upgrade_traced_type(t::Core.Const) = upgrade_traced_type(CC.widenconst(t)) +upgrade_traced_type(t::Type{<:Number}) = Reactant.TracedRNumber{t} +upgrade_traced_type(t::Type{<:Reactant.TracedRNumber}) = t + +in_tcf(an::Analysis) = begin + !isnothing(an.tree.node) +end + +invalidate_slot_definition_analysis!(an) = an.slotanalysis = nothing + +function if_type_passing!(an, frame) + in_tcf(an) || return false + last_cf = an.tree.node + last_cf isa IfStructure || return false + last_cf.header_bb == frame.currbb || return false + !last_cf.legalize[] || return false + cond = last_cf.ssa_cond + goto_if_not_index = terminator_index(frame.cfg, frame.currbb) + ssa = add_instruction!( + frame, + goto_if_not_index - 1, + Expr(:call, GlobalRef(@__MODULE__, :traced_protection), cond), + ) + invalidate_slot_definition_analysis!(an) + + (; dest::Int) = frame.src.code[goto_if_not_index + 1]::Core.GotoIfNot #shifted because of the insertion + modify_instruction!(frame, goto_if_not_index + 1, Core.GotoIfNot(ssa, dest)) + last_cf.legalize[] = true + #update frame + return true +end + +function can_upgrade_loop(an, rt) + in_tcf(an) || return false + last_cf = an.tree.node + last_cf isa ForStructure || return false + last_cf.state == Maybe || return false + is_traced(rt) || return false + return true +end + +# a = expr +# => +# a = upgrade(expr) +# do nothing if expr is already an upgrade call +#TODO: rt suspicious +function apply_slot_upgrade!(frame, pos::Int, rt)::Bool + @warn "upgrade slot $pos $rt" + stmt = frame.src.code[pos] + @assert Base.isexpr(stmt, :(=)) "$stmt" + r = stmt.args[2] + #TODO: iterate can be upgraded to a traced iterate. SSAValue, slots & literal only need stmt change. 
Others need a new stmt + if Base.isexpr(r, :call) + r.args[1] == GlobalRef(@__MODULE__, :upgrade) && return false + if r.args[1] == GlobalRef(Base, :iterate) + new_type = traced_iterator(rt) + frame.src.code[pos] = Expr( + :(=), + stmt.args[1], + Expr(:call, GlobalRef(Base, :iterate), new_type, r.args[2:end]...), + ) + return true + end + frame.src.code[pos] = stmt.args[2] + add_instruction!( + frame, + pos, + Expr( + :(=), + stmt.args[1], + Expr(:call, GlobalRef(@__MODULE__, :upgrade), Core.SSAValue(pos)), + ); + next_bb=false, + ) + elseif r isa Core.SlotNumber || r isa Core.SSAValue || true #TODO: for expr we must create a new call and the expr + frame.src.code[pos] = Expr( + :(=), stmt.args[1], Expr(:call, GlobalRef(@__MODULE__, :upgrade), stmt.args[2]) + ) + else + error("unsupported slot upgrade $stmt") + end + return true +end + +function current_top_struct(tree) + top_struct = nothing + while !isnothing(tree.node) + top_struct = tree.node + tree = tree.parent[] + end + return top_struct +end + +function get_root(tree) + while !isnothing(tree.node) + tree = tree.parent[] + end + return tree +end + +function get_first_slot_read_stack(frame, tree, slot::Core.SlotNumber, stop::Int) + node = current_top_struct(tree) + start_stmt = frame.cfg.blocks[node.header_bb].stmts.start + for stmt_index in start_stmt:stop + s = frame.src.code[stmt_index] + s isa Core.SlotNumber || continue + s.id == slot.id && return CC.block_for_inst(frame.cfg.index, stmt_index) + end + return nothing +end + +@inline function check_and_upgrade_slot!(an, frame, stmt, rt, currstate) + in_tcf(an) || return (NoUpgrade,) + stmt isa Expr || return (NoUpgrade,) + stmt.head == :(=) || return (NoUpgrade,) + last_cf = an.tree.node + rt_traced = is_traced(rt) + slot = stmt.args[1].id + slot_type::Type = CC.widenconst(currstate[slot].typ) + + #If the stmt is traced: if the slot is traced or not set, don't need to upgrade the slot + #TODO: Nothing suspicions + rt_traced && + (is_traced(slot_type) || slot_type === Union{} || slot_type == Nothing) && + return (NoUpgrade,) + + if last_cf isa IfStructure + (frame.currbb in last_cf.true_bbs || frame.currbb in last_cf.false_bbs) || + return (NoUpgrade,) + #inside a traced_if, slot must be upgraded to a traced type + sa = get_slot_analysis(an, frame)::SlotAnalysis + #TODO: approximation: use liveness analysis to precise promote local slot + # if traced + isempty( + setdiff(sa.slot_bb_usage[slot], union(last_cf.true_bbs, last_cf.false_bbs)) + ) && return (NoUpgrade,) + + #invalidate_slot_definition_analysis!(an) + return if apply_slot_upgrade!(frame, frame.currpc, rt) + (UpgradeLocally,) + else + (NoUpgrade,) + end + + #no need to change frame furthermore + elseif last_cf isa ForStructure + (last_cf.state == Traced || last_cf.state == Upgraded) || return (NoUpgrade,) + if (!rt_traced && is_traced(slot_type)) + return if apply_slot_upgrade!(frame, frame.currpc, rt) + (UpgradeLocally,) + else + (NoUpgrade,) + end + end + sa = get_slot_analysis(an, frame)::SlotAnalysis + slot_definition_pos = sa.slot_stmt_def[slot] + slot_definition_bb = CC.block_for_inst(frame.cfg, slot_definition_pos) + #local slot doesn't need to be upgrade TODO: suspicious + slot_definition_bb in last_cf.body_bbs && return (NoUpgrade,) + if slot_definition_bb == last_cf.header_bb || in_stack(an.tree, slot_definition_bb) + #stack upgrade + #the slot has been upgraded: find read of the slot inside the current traced stack: if any, we must restart the inference from there + return if apply_slot_upgrade!(frame, 
slot_definition_pos, rt) + (UpgradeDefinition, stmt.args[1]) + else + (NoUpgrade,) + end + else + #global upgrade: add a new slot + new_slot_def_pos = if last_cf.header_bb == 1 + #first block contains argument to slot write: new instructions must be placed after (otherwise all the IR is dead) + new_index = 0 + for i in frame.cfg.blocks[1].stmts + local_stmt = frame.src.code[i] + local_stmt isa Expr && + local_stmt.head == :(=) && + typeof.(frame.src.code[i].args) == + [Core.SlotNumber, Core.SlotNumber] && + continue + local_stmt isa Core.NewvarNode && continue + new_index = i + break + end + new_index + else + frame.cfg.blocks[last_cf.header_bb].stmts.start - 1 + end + #add_slot_change!(frame.src, new_slot_def_pos, slot) + slot = stmt.args[1] + #CodeInfo: Cannot use a slot inside a call + add_instruction!(frame, new_slot_def_pos, slot) + add_instruction!( + frame, + new_slot_def_pos + 1, + Expr( + :(=), + slot, + Expr( + :call, + GlobalRef(@__MODULE__, :upgrade), + Core.SSAValue(new_slot_def_pos + 1), + ), + ), + ) + invalidate_slot_definition_analysis!(an) + return (UpgradeDefinitionGlobal,) + end + return (UpgradeDefinition,) + end +end + +terminator_index(ir::Core.Compiler.IRCode, bb::Int) = terminator_index(ir.cfg, bb) +terminator_index(cfg::CC.CFG, bb::Int) = cfg.blocks[bb].stmts.stop +start_index(ir::CC.IRCode, bb::Int) = start_index(ir.cfg, bb) +start_index(cfg::CC.CFG, bb::Int) = bb == 1 ? 1 : cfg.index[bb - 1] + +#TODO: proper support this by walking the IR +function is_traced_loop_iterator(src::CC.CodeInfo, cfg::CC.CFG, bb::Int) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator_type = src.ssavaluetypes[iterator_index] + return is_traced(iterator_type) +end + +is_traced(t::Type) = parentmodule(t) == Reactant +is_traced(::Core.TypeofBottom) = false +is_traced(t::UnionAll) = is_traced(CC.unwrap_unionall(t)) +is_traced(u::Union) = (|)(is_traced.(Base.uniontypes(u))...) +function is_traced(t::Type{<:Tuple}) + t isa Union && return @invoke is_traced(t::Union) + t = Base.unwrap_unionall(t) + t isa UnionAll && return is_traced(Base.unwrap_unionall(t)) + if typeof(t) == UnionAll #NOTE: strange behavior here: some UnionAll are not handled correctly by the isa check and unwrap_unionall cannot unwrap them. + @error "strange type: $t" + t = t.body + end + return (|)(is_traced.(t.types)...) +end +is_traced(::Type{Tuple{}}) = false +is_traced(t) = false + +#TODO: add support to while loop / general loop +function is_a_traced_loop(an, src::CC.CodeInfo, cfg::CC.CFG, bb_header) + bb_body_first = min(cfg.blocks[bb_header].succs...) + preds::Vector{Int} = cfg.blocks[bb_body_first].preds + (max(preds...) < bb_body_first) && return nothing #No loop + bb_latch = max(preds...) + bb_end = max(cfg.blocks[bb_header].succs...) + bb_body_last = only(cfg.blocks[bb_latch].preds) + #TODO: proper accu and block + return ForStructure( + (), + bb_header, + bb_latch, + bb_end, + Set(bb_body_first:bb_body_last), + is_traced_loop_iterator(src, cfg, bb_header) ? 
Traced : Maybe, + ) +end + +function bb_owned_branch(domtree, bb::Int)::Set{Int} + bbs = Set(bb) + for c in domtree[bb].children + bbs = union(bbs, bb_owned_branch(domtree, c)) + end + return bbs +end + +function bb_branch(cfg, bb::Int, t_bb::Int)::Set{Int} + bbs = Set() + work = [bb] + while !isempty(work) + c_bb = pop!(work) + (c_bb in bbs || c_bb == t_bb) && continue + push!(bbs, c_bb) + for s in cfg.blocks[c_bb].succs + push!(work, s) + end + end + return bbs +end + +function get_doms(an, frame) + if an.domtree === nothing + an.domtree = CC.construct_domtree(frame.cfg).nodes + an.postdomtree = CC.construct_postdomtree(frame.cfg).nodes + end + return (an.domtree, an.postdomtree) +end + +function get_slot_analysis(an::Analysis, frame)::SlotAnalysis + if an.slotanalysis === nothing + an.slotanalysis = SlotAnalysis( + fill_slot_definition_map(frame), fill_slot_usage_map(frame) + ) + end + return an.slotanalysis +end + +#TODO:remove currbb +function is_a_traced_if(an, frame, currbb, condt) + condt == Reactant.TracedRNumber{Bool} || return nothing + (domtree, postdomtree) = get_doms(an, frame) #compute dominance analysis only when needed + bb = frame.cfg.blocks[currbb] + succs::Vector{Int64} = bb.succs + if_goto_stmt::Core.GotoIfNot = frame.src.code[last(bb.stmts)] + #CodeInfo GotoIfNot.dest is a stmt + first_false_bb = CC.block_for_inst(frame.cfg.index, if_goto_stmt.dest) + first_true_bb = succs[1] == first_false_bb ? succs[2] : succs[1] + last_child = last(domtree[currbb].children) + is_diamond = currbb in postdomtree[last_child].children + final_bb = if is_diamond + last_child + else + if_final_bb = nothing + for (final_bb, nodes) in enumerate(postdomtree) + if currbb in nodes.children + if_final_bb = final_bb + break + end + end + @assert !isnothing(if_final_bb) + if_final_bb + end + true_bbs = bb_branch(frame.cfg, first_true_bb, final_bb) + false_bbs = bb_branch(frame.cfg, first_false_bb, final_bb) + all_owned = bb_owned_branch(domtree, currbb) + true_owned_bbs = intersect(bb_owned_branch(domtree, first_true_bb), all_owned) + false_owned_bbs = intersect(bb_owned_branch(domtree, first_false_bb), all_owned) + return IfStructure( + if_goto_stmt.cond, + currbb, + final_bb, + true_bbs, + false_bbs, + true_owned_bbs, + false_owned_bbs, + Ref{Bool}(false), + ) +end + +#HACK: add a general T to Traced{T} conversion +function traced_iterator(::Type{Union{Nothing,Tuple{T,T}}}) where {T} + is_traced(T) && return T + Tout = Reactant.TracedRNumber{T} + return Union{Nothing,Tuple{Tout,Nothing}} #TODO: replace INT -> Nothing +end + +traced_iterator(t::Type{Tuple{T,T}}) where {T} = traced_iterator(Union{Nothing,t}) + +traced_iterator(t) = begin + if !is_traced(t) + error("fallback $t") + end + t +end + +function get_new_iterator_type(src::CC.CodeInfo, cfg::CC.CFG, bb::Int) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator_type = src.ssavaluetypes[iterator_index] + iterator_type = CC.widenconst(iterator_type) + return traced_iterator(iterator_type) +end + +#TODO: proper check if the iterator exists and replace -3 +function rewrite_iterator(src::CC.CodeInfo, cfg::CC.CFG, bb::Int, new_type::Type) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator = src.code[iterator_index] + iterator_arg = iterator.args[end].args[end] + iterator.args[end] = Expr(:call, GlobalRef(Base, :iterate), new_type, iterator_arg) + return iterator.args[1] +end + +function 
reset_slot!(state::Union{Nothing,Vector{Core.Compiler.VarState}}, slot::Int) + return isnothing(state) ? state : state[slot] = CC.VarState(Union{}, true) +end + +function reset_slot!( + state::Union{Nothing,Vector{Core.Compiler.VarState}}, slot::Core.SlotNumber +) + return reset_slot!(state, slot.id) +end + +function reset_slot!(states) + for i in eachindex(states) + states[i] = nothing + end +end + +function reset_slot!(states, fs::ForStructure, slot::Core.SlotNumber) + reset_slot!(states[fs.header_bb], slot) + for bb in fs.body_bbs + reset_slot!(states[bb], slot) + end + reset_slot!(states[fs.latch_bb], slot) + return reset_slot!(states[fs.terminal_bb], slot) +end + +#TODO: stack -> branch +function rewrite_loop_stack!(an::Analysis, frame, states, currstate) + (; src::CC.CodeInfo, cfg::CC.CFG) = frame + ct = an.tree + top_loop_tcf = nothing + while !isnothing(ct.node) + node = ct.node + ct = ct.parent[] + node isa ForStructure || continue + node.state == Maybe || continue + #TODO: while loop + new_iterator_type = get_new_iterator_type(src, cfg, node.header_bb) + slot = rewrite_iterator(src, cfg, node.header_bb, new_iterator_type) + last_for_bb = last(sort(collect(node.body_bbs))) + slot = rewrite_iterator(frame.src, frame.cfg, last_for_bb, new_iterator_type) + top_loop_tcf = ct + node.state = Upgraded + end + @assert(!isnothing(top_loop_tcf)) + return top_loop_tcf + #restart type inference from: top_header_rewritten +end + +#Transform an n-terminator bb IR to an 1-terminator bb IR +#TODO: improve algo: remove frame in the loop +function normalize_exit!(frame) + terminator_bbs = findall(isempty.(getfield.(frame.cfg.blocks, :succs))) + length(terminator_bbs) <= 1 && return nothing + new_slot = create_slot!(frame) + add_instruction!(frame, 0, Core.NewvarNode(new_slot)) + + n = length(frame.src.code) + add_instruction!(frame, n, new_slot) + add_instruction!(frame, n + 1, Core.ReturnNode(Core.SSAValue(n + 1))) + push!(frame.bb_vartables, nothing) + offset = 0 + tis = [terminator_index(frame.cfg, tbb) for tbb in terminator_bbs] + for tbb in tis + return_index = offset + tbb + return_ = frame.src.code[return_index] + @assert(return_ isa Core.ReturnNode) + exit_bb_start_pos = terminator_index(frame.cfg, length(frame.cfg.blocks)) + offset += if return_.val isa Core.SSAValue + temp = frame.src.code[return_.val.id] + frame.src.code[return_.val.id] = Expr(:(=), new_slot, temp) + frame.src.code[return_index] = Core.GotoNode(exit_bb_start_pos) + 0 + else + add_instruction!( + frame, return_index, Core.GotoNode(exit_bb_start_pos); next_bb=false + ) + frame.src.code[return_index] = Expr(:(=), new_slot, return_.val) + 1 + end + end + return frame.cfg = CC.compute_basic_blocks(frame.src.code) +end + +#= + CC.typeinf_local(interp::Reactant.ReactantInterpreter, frame::CC.InferenceState) + + Specialize type inference to support control flow aware tracing type inferency + TODO: enable this only for usercode because the new type inference is costly now (several type inference can be needed for a same function) +=# +function typeinf_local(interp::Reactant.ReactantInterpreter, frame::CC.InferenceState) + mod = frame.mod + if is_traced(frame.linfo.specTypes) && + !has_ancestor(mod, Core) && + !has_ancestor(mod, Base) && + !has_ancestor(mod, Reactant) + @info "auto control flow tracing enabled: $(frame.linfo)" + normalize_exit!(frame) + an = Analysis(Tree(nothing, [], Ref{Tree}()), nothing, nothing, nothing, nothing) + typeinf_local_traced(interp, frame, an) + isempty(an.tree) || 
(interp.meta_data[].traced_tree_map[mi_key(frame.linfo)] = an.tree) + else + @invoke typeinf_local(interp::CC.AbstractInterpreter, frame::CC.InferenceState) + end +end + +function update_context!(an::Analysis, currbb::Int) + isnothing(an.pending_tree) && return nothing + currbb in an.pending_tree.node || return nothing + an.tree = an.pending_tree + return an.pending_tree = nothing +end + +#= + typeinf_local_traced(interp::ReactantInterpreter, frame::CC.InferenceState) + + type infer the `frame` using a Reactant interpreter; notably detect traced control-flow and upgrade traced slot +=# +function typeinf_local_traced( + interp::Reactant.ReactantInterpreter, frame::CC.InferenceState, an::Analysis +) + @assert !CC.is_inferred(frame) + frame.dont_work_on_me = true # mark that this function is currently on the stack + W = frame.ip + ssavaluetypes = frame.ssavaluetypes + bbs = frame.cfg.blocks + nbbs = length(bbs) + 𝕃ᡒ = CC.typeinf_lattice(interp) + + currbb = frame.currbb + if currbb != 1 + currbb = frame.currbb = CC._bits_findnext(W.bits, 1)::Int # next basic block + end + + states = frame.bb_vartables + init_state = CC.copy(states[currbb]) + currstate = CC.copy(states[currbb]::CC.VarTable) + + debug_stmt = false + while currbb <= nbbs + CC.delete!(W, currbb) + bbstart = first(bbs[currbb].stmts) + bbend = last(bbs[currbb].stmts) + currpc = bbstart - 1 + update_context!(an, currbb) + up_tree!(an, currbb) + @warn frame.linfo currbb an.tree.node get_root(an.tree) + while currpc < bbend + currpc += 1 + frame.currpc = currpc + CC.empty_backedges!(frame, currpc) + stmt = frame.src.code[currpc] + # If we're at the end of the basic block ... + if currpc == bbend + # Handle control flow + if isa(stmt, Core.GotoNode) + succs = bbs[currbb].succs + @assert length(succs) == 1 + nextbb = succs[1] + ssavaluetypes[currpc] = Any + CC.handle_control_backedge!(interp, frame, currpc, stmt.label) + CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW) + @goto branch + elseif isa(stmt, Core.GotoIfNot) + condx = stmt.cond + condxslot = CC.ssa_def_slot(condx, frame) + condt = CC.abstract_eval_value(interp, condx, currstate, frame) + + if add_cf!(an, frame, currbb, currpc, condt) + @goto reset_inference + end + + if condt === CC.Bottom + ssavaluetypes[currpc] = CC.Bottom + CC.empty!(frame.pclimitations) + @goto find_next_bb + end + orig_condt = condt + if !(isa(condt, Core.Const) || isa(condt, CC.Conditional)) && + isa(condxslot, Core.SlotNumber) + # if this non-`Conditional` object is a slot, we form and propagate + # the conditional constraint on it + condt = CC.Conditional( + condxslot, Core.Const(true), Core.Const(false) + ) + end + condval = CC.maybe_extract_const_bool(condt) + nothrow = (condval !== nothing) || CC.:(βŠ‘)(𝕃ᡒ, orig_condt, Bool) + if nothrow + CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW) + else + CC.update_exc_bestguess!(interp, TypeError, frame) + CC.propagate_to_error_handler!(currstate, frame, 𝕃ᡒ) + CC.merge_effects!(interp, frame, CC.EFFECTS_THROWS) + end + + if !CC.isempty(frame.pclimitations) + # we can't model the possible effect of control + # dependencies on the return + # directly to all the return values (unless we error first) + condval isa Bool || + CC.union!(frame.limitations, frame.pclimitations) + empty!(frame.pclimitations) + end + ssavaluetypes[currpc] = Any + if condval === true + @goto fallthrough + else + if !nothrow && !CC.hasintersect(CC.widenconst(orig_condt), Bool) + ssavaluetypes[currpc] = CC.Bottom + @goto find_next_bb + end + + succs = bbs[currbb].succs + if length(succs) 
== 1 + @assert condval === false || (stmt.dest === currpc + 1) + nextbb = succs[1] + @goto branch + end + @assert length(succs) == 2 + truebb = currbb + 1 + falsebb = succs[1] == truebb ? succs[2] : succs[1] + if condval === false + nextbb = falsebb + CC.handle_control_backedge!(interp, frame, currpc, stmt.dest) + @goto branch + end + # We continue with the true branch, but process the false + # branch here. + if isa(condt, CC.Conditional) + else_change = CC.conditional_change( + 𝕃ᡒ, currstate, condt.elsetype, condt.slot + ) + if else_change !== nothing + false_vartable = CC.stoverwrite1!( + copy(currstate), else_change + ) + else + false_vartable = currstate + end + changed = CC.update_bbstate!(𝕃ᡒ, frame, falsebb, false_vartable) + then_change = CC.conditional_change( + 𝕃ᡒ, currstate, condt.thentype, condt.slot + ) + + if then_change !== nothing + CC.stoverwrite1!(currstate, then_change) + end + else + changed = CC.update_bbstate!(𝕃ᡒ, frame, falsebb, currstate) + end + if changed + CC.handle_control_backedge!(interp, frame, currpc, stmt.dest) + CC.push!(W, falsebb) + end + @goto fallthrough + end + elseif isa(stmt, Core.ReturnNode) + rt = CC.abstract_eval_value(interp, stmt.val, currstate, frame) + if CC.update_bestguess!(interp, frame, currstate, rt) + CC.update_cycle_worklists!( + frame + ) do caller::CC.InferenceState, caller_pc::Int + # no reason to revisit if that call-site doesn't affect the final result + return caller.ssavaluetypes[caller_pc] !== Any + end + end + ssavaluetypes[frame.currpc] = Any + @goto find_next_bb + elseif isa(stmt, Core.EnterNode) + ssavaluetypes[currpc] = Any + CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW) + if isdefined(stmt, :scope) + scopet = CC.abstract_eval_value( + interp, stmt.scope, currstate, frame + ) + handler = frame.handlers[frame.handler_at[frame.currpc + 1][1]] + @assert handler.scopet !== nothing + if !CC.:(βŠ‘)(𝕃ᡒ, scopet, handler.scopet) + handler.scopet = CC.tmerge(𝕃ᡒ, scopet, handler.scopet) + if isdefined(handler, :scope_uses) + for bb in handler.scope_uses + push!(W, bb) + end + end + end + end + @goto fallthrough + elseif CC.isexpr(stmt, :leave) + ssavaluetypes[currpc] = Any + @goto fallthrough + end + # Fall through terminator - treat as regular stmt + end + + # Process non control-flow statements + (; changes, rt, exct) = CC.abstract_eval_basic_statement( + interp, stmt, currstate, frame + ) + if !CC.has_curr_ssaflag(frame, CC.IR_FLAG_NOTHROW) + if exct !== Union{} + CC.update_exc_bestguess!(interp, exct, frame) + # TODO: assert that these conditions match. For now, we assume the `nothrow` flag + # to be correct, but allow the exct to be an over-approximation. + end + CC.propagate_to_error_handler!(currstate, frame, 𝕃ᡒ) + end + + #upgrade maybe for loop here: eagerly restart type inference if we detect an traced type + #NOTE: must be placed before CC.Bottom check: in a traced context, an iterator with invalid arguments still should be upgraded + #@info stmt an.tree + upgrade_result = check_and_upgrade_slot!(an, frame, stmt, rt, currstate) + slot_state = first(upgrade_result) + if slot_state === UpgradeDefinition #Slot Upgrade ... 
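+                        # (note) a slot definition earlier in the IR was rewritten
+                        # into an `upgrade` call: refresh the CFG view and restart
+                        # inference from the top so the new definition is re-inferred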
+ bbs = frame.cfg.blocks + bbend = last(bbs[currbb].stmts) + @goto reset_inference + continue + elseif slot_state === UpgradeDefinitionGlobal + @goto reset_inference + elseif slot_state === UpgradeLocally + @goto reset_inference + end + + if rt === CC.Bottom + ssavaluetypes[currpc] = CC.Bottom + # Special case: Bottom-typed PhiNodes do not error (but must also be unused) + if isa(stmt, Core.PhiNode) + continue + end + @goto find_next_bb + end + + #Slot upgrade must be placed before any slot/ssa table change + if changes !== nothing + CC.stoverwrite1!(currstate, changes) + end + if rt === nothing + ssavaluetypes[currpc] = Any + continue + end + + if can_upgrade_loop(an, rt) + rewrite_loop_stack!(an, frame, states, currstate) + @goto reset_inference + end + + # IMPORTANT: set the type + CC.record_ssa_assign!(𝕃ᡒ, currpc, rt, frame) + end # while currpc < bbend + + # Case 1: Fallthrough termination + begin + @label fallthrough + nextbb = currbb + 1 + end + + # Case 2: Directly branch to a different BB + begin + @label branch + if CC.update_bbstate!(𝕃ᡒ, frame, nextbb, currstate) + CC.push!(W, nextbb) + end + end + + # Case 3: Control flow ended along the current path (converged, return or throw) + begin + @label find_next_bb + currbb = frame.currbb = CC._bits_findnext(W.bits, 1)::Int # next basic block + currbb == -1 && break # the working set is empty + currbb > nbbs && break + nexttable = states[currbb] + if nexttable === nothing + CC.init_vartable!(currstate, frame) + else + CC.stoverwrite!(currstate, nexttable) + end + end + + begin + continue + @label reset_inference + CC.empty!(W) + currbb = 1 + frame.currbb = 1 + currpc = 1 + frame.currpc = 1 + reset_slot!(states) + an.tree = get_root(an.tree) + states[1] = copy(init_state) + currstate = copy(init_state) + for i in eachindex(frame.ssavaluetypes) + frame.ssavaluetypes[i] = CC.NotFound() + end + bbs = frame.cfg.blocks + nbbs = length(bbs) + ssavaluetypes = frame.ssavaluetypes + end + end # while currbb <= nbbs + @lk an + frame.dont_work_on_me = false + return nothing +end diff --git a/src/auto_cf/utils_bench.jl b/src/auto_cf/utils_bench.jl new file mode 100644 index 0000000000..ec0bd9bd58 --- /dev/null +++ b/src/auto_cf/utils_bench.jl @@ -0,0 +1,156 @@ +function init_mlir() + ctx = Reactant.MLIR.IR.Context() + @ccall Reactant.MLIR.API.mlir_c.RegisterDialects(ctx::Reactant.MLIR.API.MlirContext)::Cvoid +end + +get_traced_object(::Type{Reactant.TracedRNumber{T}}) where T = Reactant.Ops.constant(rand(T)) + +get_traced_object(::Type{Reactant.TracedRArray{T,N}}) where {T,N} = Reactant.Ops.constant(rand(T, [1 for i in 1:N]...)) + +get_traced_object(t) = begin + @error t + rand(t) +end + + +#= + analysis_reassign_block_id!(an::Analysis, ir::Core.IRCode, src::Core.CodeInfo) + slot2reg can change type infered CodeInfo CFG by removing non-reachable block, + ControlFlow analysis use blocks information and must be shifted + +=# +function analysis_reassign_block_id!(an::Analysis, ir::CC.IRCode, src::CC.CodeInfo) + cfg = CC.compute_basic_blocks(src.code) + length(ir.cfg.blocks) == length(cfg.blocks) && return false + @info "rewrite analysis blocks" + new_block_map = [] + i = 0 + for block in cfg.blocks + unreacheable_block = all(x->src.ssavaluetypes[x] === Union{}, block.stmts) + i = unreacheable_block ? i : i + 1 + push!(new_block_map, i) + end + @info new_block_map + function reassign_tree!(s::Set{Int}) + n = [new_block_map[i] for i in s] + empty!(s) + push!(s, n...) 
+ end + + function reassign_tree!(is::IfStructure) + is.header_bb = new_block_map[is.header_bb] + is.terminal_bb = new_block_map[is.terminal_bb] + reassign_tree!(is.true_bbs) + reassign_tree!(is.false_bbs) + reassign_tree!(is.owned_true_bbs) + reassign_tree!(is.owned_false_bbs) + end + + function reassign_tree!(fs::ForStructure) + fs.header_bb = new_block_map[fs.header_bb] + fs.latch_bb = new_block_map[fs.latch_bb] + fs.terminal_bb = new_block_map[fs.terminal_bb] + reassign_tree!(fs.body_bbs) + end + + function reassign_tree!(t::Tree) + isnothing(t.node) || reassign_tree!(t.node) + for c in t.children + reassign_tree!(c) + end + end + reassign_tree!(an.tree) + @error an.tree + return true +end + +function test(f) + m = methods(f)[1] + types = m.sig.parameters[2:end] + mi = Base.method_instance(f, types) + @lk mi + world = Base.get_world_counter() + interp = Reactant.ReactantInterpreter(; world) + resul = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + src = CC.retrieve_code_info(resul.linfo, world) + osrc = CC.copy(src) + @lk osrc src + frame = CC.InferenceState(resul, src, :no, interp) + CC.typeinf(interp, frame) + opt = CC.OptimizationState(frame, interp) + ir0 = CC.convert_to_ircode(opt.src, opt) + ir = CC.slot2reg(ir0, opt.src, opt) + analysis_reassign_block_id!(an, ir, src) + ir = CC.compact!(ir) + bir = CC.copy(ir) + @lk bir + ir_final = control_flow_transform!(an, ir) + + modu = Reactant.MLIR.IR.Module() + @lk modu + #init_caches() + Reactant.MLIR.IR.activate!(modu) + Reactant.MLIR.IR.activate!(Reactant.MLIR.IR.body(modu)) + + ttypes = collect(types)[is_traced.(types)] + @lk types ttypes + + + to_mlir(::Type{Reactant.TracedRArray{T,N}}) where {T,N} = Reactant.MLIR.IR.TensorType(repeat([4096], N), Reactant.MLIR.IR.Type(T)) + to_mlir(x) = Reactant.Ops.mlir_type(x) + f_args = to_mlir.(ttypes) + + temporal_func = Reactant.MLIR.Dialects.func.func_(; + sym_name="main_", + function_type=Reactant.MLIR.IR.FunctionType(f_args, []), + body=Reactant.MLIR.IR.Region(), + sym_visibility=Reactant.MLIR.IR.Attribute("private"), + ) + + main = Reactant.MLIR.IR.Block(f_args, [Reactant.MLIR.IR.Location() for _ in f_args]) + push!(Reactant.MLIR.IR.region(temporal_func, 1), main) + Reactant.Ops.activate_constant_context!(main) + Reactant.MLIR.IR.activate!(main) + + args = [] + i = 1 + for tt in types + if !is_traced(tt) + push!(args, rand(tt)) + continue + end + + arg = if ttypes[i] <: Reactant.TracedRArray + ttypes[i]((), nothing, repeat([4096], ttypes[i].parameters[2])) + else + ttypes[i]((), nothing) + end + Reactant.TracedUtils.set_mlir_data!(arg, Reactant.MLIR.IR.argument(main, i)) + push!(args, arg) + i += 1 + end + + + #A = Reactant.Ops.constant(rand(Int,2,2)); + #B = Reactant.Ops.constant(rand(Int,2,2)); + r = juliair_to_mlir(ir_final, args...)[2] + Reactant.Ops.return_(r...) 
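+    # (illustrative note) `r` holds the traced values emitted into `main` by
+    # `juliair_to_mlir`; they become the operands of `return_`, and the contexts
+    # activated above are popped below in reverse order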
+ Reactant.Ops.deactivate_constant_context!(main) + Reactant.MLIR.IR.deactivate!(main) + + + func = Reactant.MLIR.Dialects.func.func_(; + sym_name="main", + function_type=Reactant.MLIR.IR.FunctionType(f_args, Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type.(r)...]), + body=Reactant.MLIR.IR.Region(), + sym_visibility=Reactant.MLIR.IR.Attribute("private"), + ) + + Reactant.MLIR.API.mlirRegionTakeBody( + Reactant.MLIR.IR.region(func, 1), Reactant.MLIR.IR.region(temporal_func, 1)) + + Reactant.MLIR.API.mlirOperationDestroy(temporal_func.operation) + + Reactant.MLIR.IR.verifyall(Reactant.MLIR.IR.Operation(modu); debug=true) ||Β error("fail") + modu +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index c7cb254946..68c8637161 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,883 +18,11 @@ function apply(f::F, args...; kwargs...) where {F} return f(args...; kwargs...) end -function call_with_reactant end - -function maybe_argextype(@nospecialize(x), src) - return try - Core.Compiler.argextype(x, src) - catch err - !(err isa Core.Compiler.InvalidIRError) && rethrow() - nothing - end -end +include("JIT.jl") # Defined in KernelAbstractions Ext function ka_with_reactant end -""" - Reactant.REDUB_ARGUMENTS_NAME - -The variable name bound to `call_with_reactant`'s tuple of arguments in its -`@generated` method definition. - -This binding can be used to manually reference/destructure `call_with_reactants` arguments - -This is required because user arguments could have a name which clashes with whatever name we choose for -our argument. Thus we gensym to create it. - -This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 -""" -const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") - -function throw_method_error(argtys) - throw(MethodError(argtys[1], argtys[2:end])) -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Union{Nothing,Core.MethodTable}, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = ccall( - :jl_gf_invoke_lookup_worlds, - Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, - mt, - world, - min_world, - max_world, - ) - return res -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Core.Compiler.InternalMethodTable, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = lookup_world(sig, mt.world, nothing, min_world, max_world) - return res -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Core.Compiler.OverlayMethodTable, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) - if res !== nothing - return res - else - return lookup_world(sig, mt.world, nothing, min_world, max_world) - end -end - -function has_ancestor(query::Module, target::Module) - query == target && return true - while true - next = parentmodule(query) - next == target && return true - next == query && return false - query = next - end -end - -const __skip_rewrite_func_set_lock = ReentrantLock() -const __skip_rewrite_func_set = Set([ - # Avoid the 1.10 stackoverflow - typeof(Base.typed_hvcat), - typeof(Base.hvcat), - typeof(Core.Compiler.concrete_eval_eligible), - typeof(Core.Compiler.typeinf_type), - typeof(Core.Compiler.typeinf_ext), - # TODO: perhaps problematic calls in `traced_call` - # should be moved to TracedUtils.jl: - typeof(ReactantCore.traced_call), - typeof(ReactantCore.is_traced), - # Perf optimization - 
typeof(Base.typemax), - typeof(Base.typemin), - typeof(Base.getproperty), - typeof(Base.vect), - typeof(Base.eltype), - typeof(Base.argtail), - typeof(Base.identity), - typeof(Base.print), - typeof(Base.println), - typeof(Base.show), - typeof(Base.show_delim_array), - typeof(Base.sprint), - typeof(Adapt.adapt_structure), - typeof(Core.is_top_bit_set), - typeof(Base.setindex_widen_up_to), - typeof(Base.typejoin), - typeof(Base.argtype_decl), - typeof(Base.arg_decl_parts), - typeof(Base.StackTraces.show_spec_sig), - typeof(Core.Compiler.return_type), - typeof(Core.throw_inexacterror), - typeof(Base.throw_boundserror), - typeof(Base._shrink), - typeof(Base._shrink!), - typeof(Base.ht_keyindex), - typeof(Base.checkindex), - typeof(Base.to_index), - @static( - if VERSION >= v"1.11.0" - typeof(Base.memoryref) - end - ), - typeof(materialize_traced_array), -]) - -""" - @skip_rewrite_func f - -Mark function `f` so that Reactant's IR rewrite mechanism will skip it. -This can improve compilation time if it's safe to assume that no call inside `f` -will need a `@reactant_overlay` method. - -!!! info - Note that this marks the whole function, not a specific method with a type - signature. - -!!! warning - The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope - too. - -See also: [`@skip_rewrite_type`](@ref) -""" -macro skip_rewrite_func(fname) - quote - @lock $(Reactant.__skip_rewrite_func_set_lock) push!( - $(Reactant.__skip_rewrite_func_set), typeof($(esc(fname))) - ) - end -end - -const __skip_rewrite_type_constructor_list_lock = ReentrantLock() -const __skip_rewrite_type_constructor_list = [ - # Don't rewrite Val - Type{Base.Val}, - # Don't rewrite exception constructors - Type{<:Core.Exception}, - # Don't rewrite traced constructors - Type{<:TracedRArray}, - Type{<:TracedRNumber}, - Type{MLIR.IR.Location}, - Type{MLIR.IR.Block}, -] - -""" - @skip_rewrite_type MyStruct - @skip_rewrite_type Type{<:MyStruct} - -Mark the construct function of `MyStruct` so that Reactant's IR rewrite mechanism -will skip it. It does the same as [`@skip_rewrite_func`](@ref) but for type -constructors. - -If you want to mark the set of constructors over it's type parameters or over its -abstract type, you should use then the `Type{<:MyStruct}` syntax. - -!!! warning - The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope - too. -""" -macro skip_rewrite_type(typ) - typ = if Base.isexpr(typ, :curly) && typ.args[1] === :Type - typ - else - Expr(:curly, :Type, typ) - end - return quote - @lock $(Reactant.__skip_rewrite_type_constructor_list_lock) push!( - $(Reactant.__skip_rewrite_type_constructor_list), $(esc(typ)) - ) - end -end - -function should_rewrite_call(@nospecialize(ft)) - # Don't rewrite builtin or intrinsics - if ft <: Core.IntrinsicFunction || ft <: Core.Builtin - return false - end - if ft <: Core.Function - if hasfield(typeof(ft), :name) && - hasfield(typeof(ft.name), :name) && - isdefined(ft.name, :name) - namestr = String(ft.name.name) - if startswith(namestr, "##(overlay (. 
Reactant (inert REACTANT_METHOD_TABLE)") - return false - end - end - - # We need this for closures to work - if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module) - mod = ft.name.module - # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions - if has_ancestor(mod, Ops) || - has_ancestor(mod, TracedUtils) || - has_ancestor(mod, MLIR) - return false - end - if string(mod) == "CUDA" - if ft.name.name == Symbol("#launch_configuration") - return false - end - if ft.name.name == Symbol("cudaconvert") - return false - end - end - end - end - - # `ft isa Type` is for performance as it avoids checking against all the list, but can be removed if problematic - if ft isa Type && any(t -> ft <: t, __skip_rewrite_type_constructor_list) - return false - end - - if ft in __skip_rewrite_func_set - return false - end - - # Default assume all functions need to be reactant-ified - return true -end - -# by default, same as `should_rewrite_call` -function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) - # TODO how can we extend `@skip_rewrite` to methods? - if ft <: typeof(repeat) && (args == Tuple{String,Int64} || args == Tuple{Char,Int64}) - return false - end - return should_rewrite_call(ft) -end - -# Avoid recursively interpreting into methods we define explicitly -# as overloads, which we assume should handle the entirety of the -# translation (and if not they can use call_in_reactant). -function is_reactant_method(mi::Core.MethodInstance) - meth = mi.def - if !isdefined(meth, :external_mt) - return false - end - mt = meth.external_mt - return mt === REACTANT_METHOD_TABLE -end - -struct MustThrowError end - -@generated function applyiterate_with_reactant( - iteratefn, applyfn, args::Vararg{Any,N} -) where {N} - if iteratefn != typeof(Base.iterate) - return quote - error("Unhandled apply_iterate with iteratefn=$iteratefn") - end - end - newargs = Vector{Expr}(undef, N) - for i in 1:N - @inbounds newargs[i] = :(args[$i]...) - end - quote - Base.@_inline_meta - call_with_reactant(applyfn, $(newargs...)) - end -end - -@generated function applyiterate_with_reactant( - mt::MustThrowError, iteratefn, applyfn, args::Vararg{Any,N} -) where {N} - @assert iteratefn == typeof(Base.iterate) - newargs = Vector{Expr}(undef, N) - for i in 1:N - @inbounds newargs[i] = :(args[$i]...) - end - quote - Base.@_inline_meta - call_with_reactant(mt, applyfn, $(newargs...)) - end -end - -function certain_error() - throw( - AssertionError( - "The inferred code was guaranteed to throw this error. And yet, it didn't. So here we are...", - ), - ) -end - -function rewrite_inst(inst, ir, interp, RT, guaranteed_error) - if Meta.isexpr(inst, :call) - # Even if type unstable we do not want (or need) to replace intrinsic - # calls or builtins with our version. - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) - if ft == typeof(Core.kwcall) - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - end - if ft == typeof(Core._apply_iterate) - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - if Base.invokelatest(should_rewrite_call, ft) - if RT === Union{} - rep = Expr( - :call, - applyiterate_with_reactant, - MustThrowError(), - inst.args[2:end]..., - ) - return true, rep, Union{} - else - rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) - return true, rep, Any - end - end - elseif Base.invokelatest(should_rewrite_call, ft) - if RT === Union{} - rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) 
- return true, rep, Union{} - else - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep, Any - end - end - end - if Meta.isexpr(inst, :invoke) - omi = inst.args[1]::Core.MethodInstance - sig = omi.specTypes - ft = sig.parameters[1] - argsig = sig.parameters[2:end] - if ft == typeof(Core.kwcall) - ft = sig.parameters[3] - argsig = sig.parameters[4:end] - end - argsig = Core.apply_type(Core.Tuple, argsig...) - if Base.invokelatest(should_rewrite_invoke, ft, argsig) && !is_reactant_method(omi) - method = omi.def::Core.Method - - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - # RT = Any - - if !method.isva || !Base.isvarargtype(sig.parameters[end]) - if RT === Union{} - sig2 = Tuple{ - typeof(call_with_reactant),MustThrowError,sig.parameters... - } - else - sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} - end - else - vartup = inst.args[end] - ns = Type[] - eT = sig.parameters[end].T - for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) - push!(ns, eT) - end - if RT === Union{} - sig2 = Tuple{ - typeof(call_with_reactant), - MustThrowError, - sig.parameters[1:(end - 1)]..., - ns..., - } - else - sig2 = Tuple{ - typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... - } - end - end - - lookup_result = lookup_world( - sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world - ) - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - n_method_args = method.nargs - if RT === Union{} - rep = Expr( - :invoke, mi, call_with_reactant, MustThrowError(), inst.args[2:end]... - ) - return true, rep, Union{} - else - rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) 
- return true, rep, Any - end - end - end - if isa(inst, Core.ReturnNode) && (!isdefined(inst, :val) || guaranteed_error) - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - sig2 = Tuple{typeof(certain_error)} - - lookup_result = lookup_world( - sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world - ) - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - rep = Expr(:invoke, mi, certain_error) - return true, rep, Union{} - end - return false, inst, RT -end - -const oc_capture_vec = Vector{Any}() - -# Caching is both good to reducing compile times and necessary to work around julia bugs -# in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 -function make_oc_dict( - @nospecialize(oc_captures::Dict{FT,Core.OpaqueClosure}), - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f::FT) -)::Core.OpaqueClosure where {FT} - key = f - if haskey(oc_captures, key) - oc = oc_captures[key] - oc - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[key] = ores - return ores - end -end - -function make_oc_ref( - oc_captures::Base.RefValue{Core.OpaqueClosure}, - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f) -)::Core.OpaqueClosure - if Base.isassigned(oc_captures) - return oc_captures[] - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[] = ores - return ores - end -end - -function safe_print(name, x) - return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) -end - -const DEBUG_INTERP = Ref(false) - -# Rewrite type unstable calls to recurse into call_with_reactant to ensure -# they continue to use our interpreter. Reset the derived return type -# to Any if our interpreter would change the return type of any result. -# Also rewrite invoke (type stable call) to be :call, since otherwise apparently -# screws up type inference after this (TODO this should be fixed). 
-function rewrite_insts!(ir, interp, guaranteed_error) - any_changed = false - for (i, inst) in enumerate(ir.stmts) - # Explicitly skip any code which returns Union{} so that we throw the error - # instead of risking a segfault - RT = inst[:type] - @static if VERSION < v"1.11" - changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error) - Core.Compiler.setindex!(ir.stmts[i], next, :inst) - else - changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error) - Core.Compiler.setindex!(ir.stmts[i], next, :stmt) - end - if changed - any_changed = true - Core.Compiler.setindex!(ir.stmts[i], RT, :type) - end - end - return ir, any_changed -end - -function rewrite_argnumbers_by_one!(ir) - # Add one dummy argument at the beginning - pushfirst!(ir.argtypes, Nothing) - - # Re-write all references to existing arguments to their new index (N + 1) - for idx in 1:length(ir.stmts) - urs = Core.Compiler.userefs(ir.stmts[idx][:inst]) - changed = false - it = Core.Compiler.iterate(urs) - while it !== nothing - (ur, next) = it - old = Core.Compiler.getindex(ur) - if old isa Core.Argument - # Replace the Argument(n) with Argument(n + 1) - Core.Compiler.setindex!(ur, Core.Argument(old.n + 1)) - changed = true - end - it = Core.Compiler.iterate(urs, next) - end - if changed - @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :inst) - else - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :stmt) - end - end - end - - return nothing -end - -# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter -# In particular this entails two pieces: -# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance -# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are -# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia -# using a custom interpreter in type unstable code. 
-# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)`
-function call_with_reactant_generator(
-    world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments)
-)
-    @nospecialize
-    args = redub_arguments
-    if DEBUG_INTERP[]
-        safe_print("args", args)
-    end
-
-    stub = Core.GeneratedFunctionStub(
-        identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec()
-    )
-
-    fn = args[1]
-    sig = Tuple{args...}
-
-    guaranteed_error = false
-    if fn === MustThrowError
-        guaranteed_error = true
-        fn = args[2]
-        sig = Tuple{args[2:end]...}
-    end
-
-    # Builtins cannot be overdubbed; bail with an error
-    builtin_error =
-        :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn")))
-
-    if fn <: Core.Builtin
-        return stub(world, source, builtin_error)
-    end
-
-    if guaranteed_error
-        method_error = :(throw(
-            MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world)
-        ))
-    else
-        method_error = :(throw(
-            MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world)
-        ))
-    end
-
-    interp = ReactantInterpreter(; world)
-
-    min_world = Ref{UInt}(typemin(UInt))
-    max_world = Ref{UInt}(typemax(UInt))
-
-    # Look up the method match in our method table
-    lookup_result = lookup_world(
-        sig, world, Core.Compiler.method_table(interp), min_world, max_world
-    )
-
-    overdubbed_code = Any[]
-    overdubbed_codelocs = Int32[]
-
-    # No method could be found (including in our method table); bail with an error
-    if lookup_result === nothing
-        return stub(world, source, method_error)
-    end
-
-    match = lookup_result::Core.MethodMatch
-    # look up the method and code instance
-    mi = ccall(
-        :jl_specializations_get_linfo,
-        Ref{Core.MethodInstance},
-        (Any, Any, Any),
-        match.method,
-        match.spec_types,
-        match.sparams,
-    )
-    method = mi.def
-
-    @static if VERSION < v"1.11"
-        # For older Julia versions, we vendor in some of the inference code to avoid
-        # having to build the MethodInstance twice.
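-        # The steps below mirror what `CC.typeinf_ircode` does on 1.11+: build an
-        # InferenceResult for the MethodInstance, run inference on an InferenceState,
-        # then run the optimizer passes to obtain the IRCode and widened return type.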
-        result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
-        frame = CC.InferenceState(result, :no, interp)
-        @assert !isnothing(frame)
-        CC.typeinf(interp, frame)
-        ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
-        rt = CC.widenconst(CC.ignorelimited(result.result))
-    else
-        ir, rt = CC.typeinf_ircode(interp, mi, nothing)
-    end
-
-    if guaranteed_error
-        if rt !== Union{}
-            safe_print("Inconsistent guaranteed error IR", ir)
-        end
-        rt = Union{}
-    end
-
-    if DEBUG_INTERP[]
-        safe_print("ir", ir)
-    end
-
-    mi = mi::Core.MethodInstance
-
-    if !(
-        is_reactant_method(mi) || (
-            mi.def.sig isa DataType &&
-            !should_rewrite_invoke(
-                mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...}
-            )
-        )
-    ) || guaranteed_error
-        ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error)
-    end
-
-    rewrite_argnumbers_by_one!(ir)
-
-    src = ccall(:jl_new_code_info_uninit, Ref{Core.CodeInfo}, ())
-    src.slotnames = fill(:none, length(ir.argtypes) + 1)
-    src.slotflags = fill(zero(UInt8), length(ir.argtypes))
-    src.slottypes = copy(ir.argtypes)
-    src.rettype = rt
-    src = CC.ir_to_codeinf!(src, ir)
-
-    if DEBUG_INTERP[]
-        safe_print("src", src)
-    end
-
-    # Prepare a new code info
-    code_info = copy(src)
-    static_params = match.sparams
-    signature = sig
-
-    # Propagate edge metadata: this method is invalidated if the original function
-    # we are calling is invalidated
-    code_info.edges = Core.MethodInstance[mi]
-    code_info.min_world = min_world[]
-    code_info.max_world = max_world[]
-
-    # Rewrite this function's arguments to prepend the two new slots: the function
-    # itself (:call_with_reactant) and the REDUB_ARGUMENTS_NAME tuple of input arguments
-    code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME]
-    code_info.slotflags = UInt8[0x00, 0x00]
-    n_prepended_slots = 2
-    overdub_args_slot = Core.SlotNumber(n_prepended_slots)
-
-    # For the sake of convenience, the rest of this pass will translate `code_info`'s
-    # fields into these overdubbed equivalents instead of updating `code_info` in-place.
-    # Then, at the end of the pass, we'll reset the `code_info` fields accordingly.
-    overdubbed_code = Any[]
-    overdubbed_codelocs = Int32[]
-    function push_inst!(inst)
-        push!(overdubbed_code, inst)
-        push!(overdubbed_codelocs, code_info.codelocs[1])
-        return Core.SSAValue(length(overdubbed_code))
-    end
-
-    # Rewire the arguments from our tuple input of fn and args to the calling convention
-    # required by the base method: destructure the generated argument slots into the
-    # overdubbed method's argument slots.
-    offset = 1
-    fn_args = Any[]
-    n_method_args = method.nargs
-    n_actual_args = length(redub_arguments)
-    if guaranteed_error
-        offset += 1
-        n_actual_args -= 1
-    end
-
-    tys = []
-
-    iter_args = n_actual_args
-    if method.isva
-        iter_args = min(n_actual_args, n_method_args - 1)
-    end
-
-    for i in 1:iter_args
-        actual_argument = Expr(
-            :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset
-        )
-        arg = push_inst!(actual_argument)
-        offset += 1
-        push!(fn_args, arg)
-        push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)])
-
-        if DEBUG_INTERP[]
-            push_inst!(
-                Expr(
-                    :call,
-                    safe_print,
-                    "fn arg[" * string(length(fn_args)) * "]",
-                    fn_args[end],
-                ),
-            )
-        end
-    end
-
-    # If `method` is a varargs method, we have to restructure the original method call's
-    # trailing arguments into a tuple and assign that tuple to the expected argument slot.
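-    # Hedged worked example (hypothetical signature): for `g(x, ys...)` entered as
-    # `call_with_reactant(g, 1, 2, 3)`, `redub_arguments` is
-    # `(typeof(g), Int, Int, Int)`, so `n_method_args == 3` (counting `g` itself)
-    # and `n_actual_args == 4`: the loop above destructured slots 1-2 (`g` and `x`),
-    # and the block below packs slots 3-4 into the tuple `(2, 3)` for `ys`.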
-    if method.isva
-        trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple))
-        for i in n_method_args:n_actual_args
-            arg = push_inst!(
-                Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset)
-            )
-            push!(trailing_arguments.args, arg)
-            offset += 1
-        end
-
-        push!(fn_args, push_inst!(trailing_arguments))
-        push!(
-            tys,
-            Tuple{
-                redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]...,
-            },
-        )
-
-        if DEBUG_INTERP[]
-            push_inst!(
-                Expr(
-                    :call,
-                    safe_print,
-                    "fn arg[" * string(length(fn_args)) * "]",
-                    fn_args[end],
-                ),
-            )
-        end
-    end
-
-    ocva = false # method.isva; the opaque closure itself never takes varargs here
-    ocnargs = Int(method.nargs)
-    octup = Tuple{tys...}
-
-    # jl_new_opaque_closure forcibly executes in the current world. This means that we
-    # won't get the right inner code during compilation without special handling
-    # (i.e. call_in_world_total). Opaque closures also require taking the function
-    # argument; we could work around that if the function is stateless. But regardless,
-    # to work around the world-age issue we sadly have to create/compile the opaque
-    # closure here.
-    dict, make_oc = (Base.Ref{Core.OpaqueClosure}(), make_oc_ref)
-
-    push!(oc_capture_vec, dict)
-
-    oc = if false && Base.issingletontype(fn)
-        res = Core._call_in_world_total(
-            world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance
-        )::Core.OpaqueClosure
-    else
-        farg = nothing
-        rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg)
-        push_inst!(rep)
-        Core.SSAValue(length(overdubbed_code))
-    end
-
-    push_inst!(Expr(:call, oc, fn_args[1:end]...))
-
-    ocres = Core.SSAValue(length(overdubbed_code))
-
-    if DEBUG_INTERP[]
-        push_inst!(Expr(:call, safe_print, "ocres", ocres))
-    end
-
-    push_inst!(Core.ReturnNode(ocres))
-
-    #=== set `code_info`/`reflection` fields accordingly ===#
-
-    if code_info.method_for_inference_limit_heuristics === nothing
-        code_info.method_for_inference_limit_heuristics = method
-    end
-
-    code_info.code = overdubbed_code
-    code_info.codelocs = overdubbed_codelocs
-    code_info.ssavaluetypes = length(overdubbed_code)
-    code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
-
-    if DEBUG_INTERP[]
-        safe_print("code_info", code_info)
-    end
-
-    return code_info
-end
-
-# The generated-function entry point: all arguments arrive as a single tuple bound
-# to REDUB_ARGUMENTS_NAME, and the body is produced by call_with_reactant_generator.
-@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...)
-    $(Expr(:meta, :generated_only))
-    return $(Expr(:meta, :generated, call_with_reactant_generator))
-end
-
 @static if isdefined(Core, :BFloat16)
     nmantissa(::Type{Core.BFloat16}) = 7
 end
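 # `nmantissa` reports the number of explicit significand bits of a floating-point
 # type; for comparison with `Core.BFloat16`'s 7 bits above: Float16 has 10,
 # Float32 has 23, and Float64 has 52.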