From e8c07ba457d5cfc6970bd623cc8932ad6751ec7a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 2 Oct 2025 17:27:52 +0200 Subject: [PATCH 1/8] access enzyme_context through gutils --- src/Enzyme.jl | 1 + src/api.jl | 9 +++++++++ src/compiler.jl | 10 +++++++--- src/gradientutils.jl | 15 ++++++++++++++- src/logic.jl | 2 +- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 5494830e61..102ecd5a43 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -133,6 +133,7 @@ Base.convert(::Type{API.CDerivativeMode}, ::ForwardMode) = API.DEM_ForwardMode function guess_activity end mutable struct EnzymeContext + world::UInt64 end include("logic.jl") diff --git a/src/api.jl b/src/api.jl index df1a166056..51a4fd31d8 100644 --- a/src/api.jl +++ b/src/api.jl @@ -579,6 +579,15 @@ EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64, LLVMTypeRef), width, T) +function EnzymeGradientUtilsGetExternalContext(gutils) + ccall( + (:EnzymeGradientUtilsGetExternalContext, libEnzyme), + Ptr{Cvoid}, + (EnzymeGradientUtilsRef,), + gutils, + ) +end + EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall( (:EnzymeGradientUtilsReplaceAWithB, libEnzyme), Cvoid, diff --git a/src/compiler.jl b/src/compiler.jl index 96b1c01c5d..4cfe5a1eb7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -223,8 +223,12 @@ include("compiler/utils.jl") include("compiler/orcv2.jl") -include("gradientutils.jl") - +import .Enzyme: GradientUtils, call_samefunc_with_inverted_bundles!, + get_width, get_mode, get_runtime_activity, + get_strong_zero, get_shadow_type, get_uncacheable, + erase_with_placeholder, is_constant_value, is_constant_inst, + new_from_original, lookup_value, invert_pointer, debug_from_orig!, + add_reverse_block!, set_reverse_block!, enzyme_context # Julia function to LLVM stem and arity const cmplx_known_ops = @@ -2507,7 +2511,7 @@ function enzyme!( convert(API.CDIFFE_TYPE, rt) end - enzyme_context = EnzymeContext() + enzyme_context = EnzymeContext(job.world) GC.@preserve enzyme_context begin LLVM.@dispose logic = Logic(enzyme_context) begin diff --git a/src/gradientutils.jl b/src/gradientutils.jl index b9b27028e8..4848adb56e 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -66,6 +66,7 @@ erase_with_placeholder( orig::LLVM.Instruction, erase::Bool = true, ) = API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase) + is_constant_value(gutils::GradientUtils, val::LLVM.Value) = API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0 @@ -96,4 +97,16 @@ end function set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock) return LLVM.BasicBlock(API.EnzymeGradientUtilsSetReverseBlock(gutils, block)) -end \ No newline at end of file +end + +function enzyme_context(gutils::GradientUtils) + ptr = API.EnzymeGradientUtilsGetExternalContext(gutils) + @assert ptr != C_NULL + return unsafe_pointer_to_objref(ptr)::EnzymeContext +end + +function enzyme_gutils_context(gutils::API.EnzymeGradientUtilsRef) + ptr = API.EnzymeGradientUtilsGetExternalContext(gutils) + @assert ptr != C_NULL + return unsafe_pointer_to_objref(ptr)::EnzymeContext +end diff --git a/src/logic.jl b/src/logic.jl index 58a45a9cca..24750e50e5 100644 --- a/src/logic.jl +++ b/src/logic.jl @@ -18,7 +18,7 @@ function enzyme_context(logic::Logic) return logic.ctx::EnzymeContext end -function enzyme_context(logic::API.EnzymeLogicRef) +function enzyme_logic_context(logic::API.EnzymeLogicRef) ptr = API.LogicGetExternalContext(logic) @assert ptr != C_NULL return unsafe_pointer_to_objref(ptr)::EnzymeContext From 0246aed1cb56cf6cc42dcd695b09825ce034480a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 2 Oct 2025 17:32:08 +0200 Subject: [PATCH 2/8] assert check the world we encode in enzyme_extrace_world --- src/compiler.jl | 2 ++ src/errors.jl | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4cfe5a1eb7..7acaaa64c0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1083,6 +1083,7 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV end world = enzyme_extract_world(f) + @assert world == interp.world if expectLen != length(parameters(f)) continue @@ -1668,6 +1669,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie mode == API.DEM_ReverseModeCombined fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) + @assert world == enzyme_gutils_context(gutils).world if !guaranteed_nonactive(Ty, world) B = LLVM.IRBuilder() position!(B, V) diff --git a/src/errors.jl b/src/errors.jl index b5dd4a0ac9..74932eab10 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -475,6 +475,7 @@ function julia_error( ) #=error=# world = enzyme_extract_world(f) end + @assert world == enzyme_gutils_context(gutils).world throw(IllegalTypeAnalysisException(msg, mi, world, sval, ir, bt)) elseif errtype == API.ET_NoType @assert B != C_NULL @@ -550,6 +551,7 @@ function julia_error( illegal = false created = LLVM.Instruction[] world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B)))) + @assert world == enzyme_context(gutils).world width = get_width(gutils) function make_batched(@nospecialize(cur::LLVM.Value), B::LLVM.IRBuilder)::LLVM.Value if width == 1 @@ -944,7 +946,7 @@ end end end - mi = nothing + mi = nothing world = nothing if isa(val, LLVM.Instruction) @@ -962,6 +964,7 @@ end ) #=error=# world = enzyme_extract_world(f) end + @assert world == enzyme_gutils_context(gutils).world mode = Enzyme.API.DEM_ReverseModeCombined if mi !== nothing From 577957e74302b8dfd17e8d51a7194faee2bd0a81 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 2 Oct 2025 20:08:24 +0200 Subject: [PATCH 3/8] fixup! access enzyme_context through gutils --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7acaaa64c0..194078b4aa 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -228,7 +228,7 @@ import .Enzyme: GradientUtils, call_samefunc_with_inverted_bundles!, get_strong_zero, get_shadow_type, get_uncacheable, erase_with_placeholder, is_constant_value, is_constant_inst, new_from_original, lookup_value, invert_pointer, debug_from_orig!, - add_reverse_block!, set_reverse_block!, enzyme_context + add_reverse_block!, set_reverse_block!, enzyme_context, enzyme_gutils_context # Julia function to LLVM stem and arity const cmplx_known_ops = From 03feb2e85d856e6040923974de5b33d389c01068 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 3 Oct 2025 18:08:04 +0200 Subject: [PATCH 4/8] fixup! fixup! access enzyme_context through gutils --- src/compiler.jl | 2 +- src/errors.jl | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 194078b4aa..d7a7bc9ced 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1669,7 +1669,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie mode == API.DEM_ReverseModeCombined fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) - @assert world == enzyme_gutils_context(gutils).world + @assert world == enzyme_context(gutils).world if !guaranteed_nonactive(Ty, world) B = LLVM.IRBuilder() position!(B, V) diff --git a/src/errors.jl b/src/errors.jl index 74932eab10..e37d82eb97 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -475,7 +475,8 @@ function julia_error( ) #=error=# world = enzyme_extract_world(f) end - @assert world == enzyme_gutils_context(gutils).world + # TODO: get world from TypeAnalyzer + # @assert world == enzyme_gutils_context(gutils).world throw(IllegalTypeAnalysisException(msg, mi, world, sval, ir, bt)) elseif errtype == API.ET_NoType @assert B != C_NULL @@ -964,7 +965,9 @@ end ) #=error=# world = enzyme_extract_world(f) end - @assert world == enzyme_gutils_context(gutils).world + # what is data? + # Can we get world here? + # @assert world == enzyme_context(gutils).world mode = Enzyme.API.DEM_ReverseModeCombined if mi !== nothing From ae19b0964bc88dc9267abc3081df8db1e8a542d6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 3 Oct 2025 19:38:13 +0200 Subject: [PATCH 5/8] add some more uses of enzyme_context to access the world --- src/rules/activityrules.jl | 1 + src/rules/customrules.jl | 6 ++++++ src/rules/llvmrules.jl | 1 + src/rules/parallelrules.jl | 6 ++++++ src/rules/typeunstablerules.jl | 3 +++ 5 files changed, 17 insertions(+) diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index 1ce499d91c..ff7cc66c9d 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -31,6 +31,7 @@ function julia_activity_rule(f::LLVM.Function) return end world = enzyme_extract_world(f) + # TODO: Access to gutils # TODO fix the attributor inlining such that this can assert always true if expectLen != length(parameters(f)) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 2a86748722..df1533f5dc 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -159,6 +159,7 @@ function enzyme_custom_setup_args( ofn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(ofn) + @assert world == enzyme_context(gutils).world for arg in jlargs @assert arg.cc != RemovedParam @@ -468,6 +469,7 @@ function enzyme_custom_setup_ret( mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(LLVM.parent(orig))) + @assert world == enzyme_context(gutils).world needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) @@ -591,6 +593,7 @@ end curent_bb = position(B) fn = LLVM.parent(curent_bb) world = enzyme_extract_world(fn) + @assert world == enzyme_context(gutils).world llvmf = nested_codegen!(mode, mod, fmi, world) @@ -826,6 +829,7 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) + @assert world == enzyme_context(gutils).world C = EnzymeRules.RevConfig{ Bool(needsPrimal), @@ -939,6 +943,7 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) + @assert world == enzyme_context(gutils).world @safe_debug "Trying to apply custom forward rule" TT isKWCall functy = if isKWCall @@ -1049,6 +1054,7 @@ function enzyme_custom_common_rev( curent_bb = position(B) fn = LLVM.parent(curent_bb) world = enzyme_extract_world(fn) + @assert world == enzyme_context(gutils).world mode = get_mode(gutils) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 12cb93df5d..2cbd83d3fb 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1932,6 +1932,7 @@ end fn = LLVM.parent(LLVM.parent(orig)) world = enzyme_extract_world(fn) + @assert world == enzyme_context(gutils).world if !guaranteed_nonactive(ET, world) emit_error(B, orig, "Enzyme: element type $ET of generic_memory_copyto is potentially active ($reg) and not presently supported") end diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index f99ec02a78..35ffecede0 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -227,6 +227,7 @@ end modifiedBetween = (mode != API.DEM_ForwardMode, false) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world pfuncT = funcT @@ -550,6 +551,7 @@ end tt = Tuple{thunkTy,dfuncT,Bool} mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) @@ -594,6 +596,7 @@ end } mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) push!(function_attributes(entry), EnumAttribute("alwaysinline")) @@ -627,6 +630,7 @@ end @register_rev function threadsfor_rev(B, orig, gutils, tape) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return end @@ -675,6 +679,7 @@ end mode = get_mode(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world ops = collect(operands(orig)) @@ -731,6 +736,7 @@ end ModifiedBetween = (uncacheable[1] != 0,) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world ops = collect(operands(orig)) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 05e7a59c95..96779f2139 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -442,6 +442,7 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) width = get_width(gutils) world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world @assert is_constant_value(gutils, origops[offset]) icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] @@ -932,6 +933,7 @@ end else @assert legal world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world if !guaranteed_nonactive(TT, world) unsafe_store!(tapeR, shadowres.ref) end @@ -1034,6 +1036,7 @@ end if legal @assert legal world = enzyme_extract_world(LLVM.parent(position(B))) + @assert world == enzyme_context(gutils).world torun = !guaranteed_nonactive(TT, world) else torun = true From ee6caadd8a578f7d71845fe8465c9b833090b017 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 3 Oct 2025 19:41:46 +0200 Subject: [PATCH 6/8] add world as an argument to julia_activity_rule --- src/compiler.jl | 2 +- src/rules/activityrules.jl | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d7a7bc9ced..8574c2743e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -629,7 +629,7 @@ end name = meth.name jlmod = meth.module - julia_activity_rule(llvmfn) + julia_activity_rule(llvmfn, world) if has_custom_rule handleCustom( state, diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index ff7cc66c9d..e9deb1e770 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -1,5 +1,5 @@ -function julia_activity_rule(f::LLVM.Function) +function julia_activity_rule(f::LLVM.Function, world) if startswith(LLVM.name(f), "japi3") || startswith(LLVM.name(f), "japi1") return end @@ -30,8 +30,6 @@ function julia_activity_rule(f::LLVM.Function) if mi.specTypes.parameters[end] === Vararg{Any} return end - world = enzyme_extract_world(f) - # TODO: Access to gutils # TODO fix the attributor inlining such that this can assert always true if expectLen != length(parameters(f)) From 108da13713e23c299798ad5cffc6f629e99151ba Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 3 Oct 2025 19:43:49 +0200 Subject: [PATCH 7/8] add todo to julia_error --- src/errors.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/errors.jl b/src/errors.jl index e37d82eb97..11137bd670 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -965,6 +965,7 @@ end ) #=error=# world = enzyme_extract_world(f) end + # TODO(vchuravy) # what is data? # Can we get world here? # @assert world == enzyme_context(gutils).world From b623edffdef933681fcea1a08fd89c1a6a306095 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 7 Oct 2025 21:23:19 +0200 Subject: [PATCH 8/8] pass context to enzyme! --- src/compiler.jl | 33 ++++++++++++++++++++------------- src/rules/customrules.jl | 12 +++++------- src/rules/parallelrules.jl | 15 ++++++--------- src/rules/typeunstablerules.jl | 2 +- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8574c2743e..58d2ff4d9b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -486,12 +486,13 @@ include("llvm/transforms.jl") include("llvm/passes.jl") include("typeutils/make_zero.jl") -function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) - funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world) - nested_codegen!(mode, mod, funcspec, world) +function nested_codegen!(ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type)) + funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, ctx.world) + nested_codegen!(ctx, mode, mod, funcspec) end function prepare_llvm(interp, mod::LLVM.Module, job, meta) + # TODO: remove enzymejl_world for f in functions(mod) attributes = function_attributes(f) push!(attributes, StringAttribute("enzymejl_world", string(job.world))) @@ -1234,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false) const DumpPostNestedOpt = Ref(false) function nested_codegen!( + ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, - world::UInt, ) + world = ctx.world # TODO: Put a cache here index on `mod` and f->tt @@ -1254,6 +1256,7 @@ function nested_codegen!( GPUCompiler.prepare_job!(job) otherMod, meta = GPUCompiler.emit_llvm(job) + # TODO: interp should be cached since it contains internal caches interp = GPUCompiler.get_interpreter(job) prepare_llvm(interp, otherMod, job, meta) @@ -2398,6 +2401,7 @@ const DumpPostEnzyme = Ref(false) const DumpPostWrap = Ref(false) function enzyme!( + enzyme_context::EnzymeContext, job::CompilerJob, interp, mod::LLVM.Module, @@ -2513,7 +2517,6 @@ function enzyme!( convert(API.CDIFFE_TYPE, rt) end - enzyme_context = EnzymeContext(job.world) GC.@preserve enzyme_context begin LLVM.@dispose logic = Logic(enzyme_context) begin @@ -2583,6 +2586,7 @@ function enzyme!( if wrap augmented_primalf = create_abi_wrapper( + enzyme_context, augmented_primalf, TT, rt, @@ -2592,7 +2596,6 @@ function enzyme!( width, returnPrimal, shadow_init, - world, interp, runtimeActivity, ) @@ -2625,6 +2628,7 @@ function enzyme!( ) #=atomicAdd=# if wrap adjointf = create_abi_wrapper( + enzyme_context, adjointf, TT, rt, @@ -2634,7 +2638,6 @@ function enzyme!( width, false, shadow_init, - world, interp, runtimeActivity ) #=returnPrimal=# @@ -2666,6 +2669,7 @@ function enzyme!( augmented_primalf = nothing if wrap adjointf = create_abi_wrapper( + enzyme_context, adjointf, TT, rt, @@ -2675,7 +2679,6 @@ function enzyme!( width, returnPrimal, shadow_init, - world, interp, runtimeActivity ) @@ -2711,6 +2714,7 @@ function enzyme!( if wrap pf = adjointf adjointf = create_abi_wrapper( + enzyme_context, adjointf, TT, rt, @@ -2720,7 +2724,6 @@ function enzyme!( width, returnPrimal, shadow_init, - world, interp, runtimeActivity ) @@ -2792,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp) end function create_abi_wrapper( + ctx::EnzymeContext, enzymefn::LLVM.Function, @nospecialize(TT::Type), @nospecialize(rettype::Type), @@ -2801,10 +2805,10 @@ function create_abi_wrapper( width::Int, returnPrimal::Bool, shadow_init::Bool, - world::UInt, interp, runtime_activity::Bool ) + world = ctx.world is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal needs_tape = Mode == API.DEM_ReverseModeGradient @@ -3087,6 +3091,7 @@ function create_abi_wrapper( realparms = LLVM.Value[] i = 1 + # TODO(vchuravy): remove for attr in collect(function_attributes(enzymefn)) if kind(attr) == "enzymejl_world" push!(function_attributes(llvm_f), attr) @@ -3231,7 +3236,7 @@ function create_abi_wrapper( elseif T <: BatchDuplicatedFunc Func = get_func(T) funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world) - llvmf = nested_codegen!(Mode, mod, funcspec, world) + llvmf = nested_codegen!(ctx, Mode, mod, funcspec) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) Func_RT = return_type(interp, funcspec) @assert Func_RT == NTuple{width,T′} @@ -5102,6 +5107,7 @@ end end end + ctx = EnzymeContext(job.world) if params.run_enzyme # Generate the adjoint memcpy_alloca_to_loadstore(mod) @@ -5109,8 +5115,9 @@ end API.EnzymeDetectReadonlyOrThrow(mod) adjointf, augmented_primalf, TapeType = enzyme!( + ctx, job, - interp, + interp, mod, primalf, TT, @@ -5209,7 +5216,7 @@ end fname = String(name) * pf if haskey(functions(mod), fname) funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job.world) - llvmf = nested_codegen!(mode, mod, funcspec, job.world) + llvmf = nested_codegen!(ctx, mode, mod, funcspec) push!(function_attributes(llvmf), StringAttribute("implements", fname)) end end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index df1533f5dc..0937a75870 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -592,10 +592,8 @@ end curent_bb = position(B) fn = LLVM.parent(curent_bb) - world = enzyme_extract_world(fn) - @assert world == enzyme_context(gutils).world - llvmf = nested_codegen!(mode, mod, fmi, world) + llvmf = nested_codegen!(enzyme_context(gutils), mode, mod, fmi) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -1053,8 +1051,8 @@ function enzyme_custom_common_rev( curent_bb = position(B) fn = LLVM.parent(curent_bb) - world = enzyme_extract_world(fn) - @assert world == enzyme_context(gutils).world + ctx = enzyme_context(gutils) + world = ctx.world mode = get_mode(gutils) @@ -1115,7 +1113,7 @@ function enzyme_custom_common_rev( applicablefn = true if forward - llvmf = nested_codegen!(mode, mod, ami, world) + llvmf = nested_codegen!(ctx, mode, mod, ami) @assert llvmf !== nothing rev_RT = nothing else @@ -1157,7 +1155,7 @@ function enzyme_custom_common_rev( rmi = rmi::Core.MethodInstance rev_RT = rev_RT::Type - llvmf = nested_codegen!(mode, mod, rmi, world) + llvmf = nested_codegen!(ctx, mode, mod, rmi) end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 35ffecede0..f86ac3ce87 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -550,9 +550,8 @@ end tt = Tuple{thunkTy,dfuncT,Bool} mode = get_mode(gutils) - world = enzyme_extract_world(LLVM.parent(position(B))) - @assert world == enzyme_context(gutils).world - entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world) + ctx = enzyme_context(gutils) + entry = nested_codegen!(ctx, mode, mod, runtime_pfor_fwd, tt) push!(function_attributes(entry), EnumAttribute("alwaysinline")) pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid})) @@ -595,9 +594,8 @@ end Bool, } mode = get_mode(gutils) - world = enzyme_extract_world(LLVM.parent(position(B))) - @assert world == enzyme_context(gutils).world - entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world) + ctx = enzyme_context(gutils) + entry = nested_codegen!(ctx, mode, mod, runtime_pfor_augfwd, tt) push!(function_attributes(entry), EnumAttribute("alwaysinline")) pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid})) @@ -629,8 +627,6 @@ end @register_rev function threadsfor_rev(B, orig, gutils, tape) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - world = enzyme_extract_world(LLVM.parent(position(B))) - @assert world == enzyme_context(gutils).world if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return end @@ -653,7 +649,8 @@ end Bool, } mode = get_mode(gutils) - entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world) + ctx = enzyme_context(gutils) + entry = nested_codegen!(ctx, mode, mod, runtime_pfor_rev, tt) push!(function_attributes(entry), EnumAttribute("alwaysinline")) pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid})) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 96779f2139..c459414b5b 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1036,7 +1036,7 @@ end if legal @assert legal world = enzyme_extract_world(LLVM.parent(position(B))) - @assert world == enzyme_context(gutils).world + @assert world == enzyme_context(gutils).world torun = !guaranteed_nonactive(TT, world) else torun = true