From b5387bb69cc9f8113ba25ea8083f93e73689083b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 12 Feb 2025 17:36:33 +0100 Subject: [PATCH 1/4] init if --- src/ext.jl | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 20 +++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 src/ext.jl diff --git a/src/ext.jl b/src/ext.jl new file mode 100644 index 0000000000..bf8ea685b9 --- /dev/null +++ b/src/ext.jl @@ -0,0 +1,87 @@ +CC = Core.Compiler +ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{ + typeof(Reactant.set_reactant_abi) +} +EnzymeInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter + +shift_off(s, _) = s +shift_off(s::Core.SSAValue, new_index::Vector) = Core.SSAValue(new_index[s.id]) + +#add a conversion to Bool before a lowered if +goto_if_not_protection(src::Core.CodeInfo) = begin + new_index = [] + offset = 0 + for (i, t) in enumerate(typeof.(src.code)) + t == Core.GotoIfNot && (offset += 2) + push!(new_index, i + offset) + end + + nc = [] + ncl = [] + for (i, c) in enumerate(src.code) + v = nothing + if c isa Core.GotoIfNot + push!(nc, GlobalRef(Main, :convert)) + push!(nc, Expr(:call, (Core.SSAValue(new_index[i] - 2), GlobalRef(Main, :Bool), shift_off(c.cond, new_index))...)) + append!(ncl, [src.codelocs[i] for _ in 1:2]) + v = Core.GotoIfNot(Core.SSAValue(new_index[i] - 1), new_index[c.dest]) + elseif c isa Core.GotoNode + v = Core.GotoNode(new_index[c.label]) + elseif c isa Core.ReturnNode + v = Core.ReturnNode(shift_off(c.val, new_index)) + elseif c isa Expr + v = Expr(c.head, (shift_off(a, new_index) for a in c.args)...) + else + v = c + end + push!(nc, v) + push!(ncl, src.codelocs[i]) + end + new = copy(src) + new.code = nc + new.codelocs = ncl + for _ in 1:offset + push!(new.ssaflags, 0x00000000) + end + new.ssavaluetypes = src.ssavaluetypes + offset + return new +end + +vec = [] +vec2 = [] +function CC.inlining_policy( + interp::ReactantInter, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), + stmt_flag::UInt32, +) + #typeof(src) in [CC.IRCode, Core.CodeInfo] || return; + #push!(vec, (CC.copy(src), info)) + #push!(vec2, stacktrace()) + #=info isa CC.ConstCallInfo && (info = info.call) + push!(vec, info) + if info isa MethodMatchInfo + mm::Core.MethodMatch = first(info.results.matches) + m::Method = mm.method + if m.name == :convert && m.sig isa DataType + if m.sig.types[3] == Reactant.TracedRNumber{Bool} + return true + end + end + end + #push!(vec2, info) + =# + return nothing + @invoke CC.inlining_policy( + interp::EnzymeInter, src, info::CC.CallInfo, stmt_flag::UInt32 + ) +end + +#=vec2 = [] +CC.finish!(ji::ReactantInter, caller::CC.InferenceState) = begin + res = @invoke CC.finish!(ji::EnzymeInter, caller::CC.InferenceState) + push!(vec2, res) +end +=# + + diff --git a/src/utils.jl b/src/utils.jl index 36c4ca7fac..8148aaaa90 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -475,6 +475,11 @@ function rewrite_insts!(ir, interp, guaranteed_error) return ir, any_changed end +include("ext.jl") + +global dico = Dict() +global dico2 = Dict() + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -562,7 +567,17 @@ function call_with_reactant_generator( ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) rt = CC.widenconst(CC.ignorelimited(result.result)) else - ir, rt = CC.typeinf_ircode(interp, mi, nothing) + result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) + world = CC.get_inference_world(interp) + src = CC.retrieve_code_info(result.linfo, world) + dico2[mi]=(CC.copy(src), goto_if_not_protection(src)) + #src = goto_if_not_protection(src) + CC.maybe_validate_code(result.linfo, src, "lowered") + frame = CC.InferenceState(result, src, :no, interp) + CC.typeinf(interp, frame) + opt = CC.OptimizationState(frame, interp) + ir = CC.run_passes_ipo_safe(opt.src, opt, result) + rt = CC.widenconst(CC.ignorelimited(result.result)) end if guaranteed_error @@ -768,3 +783,6 @@ end $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end + + + From b9a476c354c2717929cf066ba523b4b6c9a0c279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 12 Feb 2025 17:57:41 +0100 Subject: [PATCH 2/4] test 1 --- src/utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8148aaaa90..a98c911cfd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -570,8 +570,9 @@ function call_with_reactant_generator( result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) world = CC.get_inference_world(interp) src = CC.retrieve_code_info(result.linfo, world) - dico2[mi]=(CC.copy(src), goto_if_not_protection(src)) - #src = goto_if_not_protection(src) + #dico2[mi]=(CC.copy(src), goto_if_not_protection(src)) + @error src + src = goto_if_not_protection(src) CC.maybe_validate_code(result.linfo, src, "lowered") frame = CC.InferenceState(result, src, :no, interp) CC.typeinf(interp, frame) From 45a4df01ef562709ec55349767dbfe20f06958f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 12 Feb 2025 19:44:19 +0100 Subject: [PATCH 3/4] test 2 --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index a98c911cfd..baf6739fbd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -573,6 +573,7 @@ function call_with_reactant_generator( #dico2[mi]=(CC.copy(src), goto_if_not_protection(src)) @error src src = goto_if_not_protection(src) + @error src CC.maybe_validate_code(result.linfo, src, "lowered") frame = CC.InferenceState(result, src, :no, interp) CC.typeinf(interp, frame) From 761f7b0b51340e6101ea481eaee605aa3210f3f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 13 Feb 2025 18:53:40 +0100 Subject: [PATCH 4/4] fix --- src/ext.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/ext.jl b/src/ext.jl index bf8ea685b9..a64e31b5f6 100644 --- a/src/ext.jl +++ b/src/ext.jl @@ -7,6 +7,13 @@ EnzymeInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter shift_off(s, _) = s shift_off(s::Core.SSAValue, new_index::Vector) = Core.SSAValue(new_index[s.id]) + +apply(c::Expr, new_index) = begin + return Expr(c.head, (shift_off(apply(a, new_index), new_index) for a in c.args)...) +end + +apply(c, _new_index) = c + #add a conversion to Bool before a lowered if goto_if_not_protection(src::Core.CodeInfo) = begin new_index = [] @@ -30,7 +37,7 @@ goto_if_not_protection(src::Core.CodeInfo) = begin elseif c isa Core.ReturnNode v = Core.ReturnNode(shift_off(c.val, new_index)) elseif c isa Expr - v = Expr(c.head, (shift_off(a, new_index) for a in c.args)...) + v = apply(c, new_index) else v = c end