diff --git a/src/ext.jl b/src/ext.jl new file mode 100644 index 0000000000..a64e31b5f6 --- /dev/null +++ b/src/ext.jl @@ -0,0 +1,94 @@ +CC = Core.Compiler +ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{ + typeof(Reactant.set_reactant_abi) +} +EnzymeInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter + +shift_off(s, _) = s +shift_off(s::Core.SSAValue, new_index::Vector) = Core.SSAValue(new_index[s.id]) + + +apply(c::Expr, new_index) = begin + return Expr(c.head, (shift_off(apply(a, new_index), new_index) for a in c.args)...) +end + +apply(c, _new_index) = c + +#add a conversion to Bool before a lowered if +goto_if_not_protection(src::Core.CodeInfo) = begin + new_index = [] + offset = 0 + for (i, t) in enumerate(typeof.(src.code)) + t == Core.GotoIfNot && (offset += 2) + push!(new_index, i + offset) + end + + nc = [] + ncl = [] + for (i, c) in enumerate(src.code) + v = nothing + if c isa Core.GotoIfNot + push!(nc, GlobalRef(Main, :convert)) + push!(nc, Expr(:call, (Core.SSAValue(new_index[i] - 2), GlobalRef(Main, :Bool), shift_off(c.cond, new_index))...)) + append!(ncl, [src.codelocs[i] for _ in 1:2]) + v = Core.GotoIfNot(Core.SSAValue(new_index[i] - 1), new_index[c.dest]) + elseif c isa Core.GotoNode + v = Core.GotoNode(new_index[c.label]) + elseif c isa Core.ReturnNode + v = Core.ReturnNode(shift_off(c.val, new_index)) + elseif c isa Expr + v = apply(c, new_index) + else + v = c + end + push!(nc, v) + push!(ncl, src.codelocs[i]) + end + new = copy(src) + new.code = nc + new.codelocs = ncl + for _ in 1:offset + push!(new.ssaflags, 0x00000000) + end + new.ssavaluetypes = src.ssavaluetypes + offset + return new +end + +vec = [] +vec2 = [] +function CC.inlining_policy( + interp::ReactantInter, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), + stmt_flag::UInt32, +) + #typeof(src) in [CC.IRCode, Core.CodeInfo] || return; + #push!(vec, (CC.copy(src), info)) + #push!(vec2, stacktrace()) + #=info isa CC.ConstCallInfo && (info = info.call) + push!(vec, info) + if info isa MethodMatchInfo + mm::Core.MethodMatch = first(info.results.matches) + m::Method = mm.method + if m.name == :convert && m.sig isa DataType + if m.sig.types[3] == Reactant.TracedRNumber{Bool} + return true + end + end + end + #push!(vec2, info) + =# + return nothing + @invoke CC.inlining_policy( + interp::EnzymeInter, src, info::CC.CallInfo, stmt_flag::UInt32 + ) +end + +#=vec2 = [] +CC.finish!(ji::ReactantInter, caller::CC.InferenceState) = begin + res = @invoke CC.finish!(ji::EnzymeInter, caller::CC.InferenceState) + push!(vec2, res) +end +=# + + diff --git a/src/utils.jl b/src/utils.jl index 36c4ca7fac..baf6739fbd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -475,6 +475,11 @@ function rewrite_insts!(ir, interp, guaranteed_error) return ir, any_changed end +include("ext.jl") + +global dico = Dict() +global dico2 = Dict() + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -562,7 +567,19 @@ function call_with_reactant_generator( ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) rt = CC.widenconst(CC.ignorelimited(result.result)) else - ir, rt = CC.typeinf_ircode(interp, mi, nothing) + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + world = CC.get_inference_world(interp) + src = CC.retrieve_code_info(result.linfo, world) + #dico2[mi]=(CC.copy(src), goto_if_not_protection(src)) + @error src + src = goto_if_not_protection(src) + @error src + CC.maybe_validate_code(result.linfo, src, "lowered") + frame = CC.InferenceState(result, src, :no, interp) + CC.typeinf(interp, frame) + opt = CC.OptimizationState(frame, interp) + ir = CC.run_passes_ipo_safe(opt.src, opt, result) + rt = CC.widenconst(CC.ignorelimited(result.result)) end if guaranteed_error @@ -768,3 +785,6 @@ end $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end + + +