Skip to content

Commit 70ef0dd

Browse files
committed
access enzyme_context through gutils
1 parent cf1f006 commit 70ef0dd

File tree

5 files changed

+32
-5
lines changed

5 files changed

+32
-5
lines changed

src/Enzyme.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ Base.convert(::Type{API.CDerivativeMode}, ::ForwardMode) = API.DEM_ForwardMode
133133
function guess_activity end
134134

135135
mutable struct EnzymeContext
136+
world::UInt64
136137
end
137138

138139
include("logic.jl")

src/api.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,15 @@ EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall
579579
EnzymeGetShadowType(width, T) =
580580
ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64, LLVMTypeRef), width, T)
581581

582+
function EnzymeGradientUtilsGetExternalContext(gutils)
583+
ccall(
584+
(:EnzymeGradientUtilsGetExternalContext, libEnzyme),
585+
Ptr{Cvoid},
586+
(EnzymeGradientUtilsRef,),
587+
gutils,
588+
)
589+
end
590+
582591
EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall(
583592
(:EnzymeGradientUtilsReplaceAWithB, libEnzyme),
584593
Cvoid,

src/compiler.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ include("compiler/utils.jl")
223223

224224
include("compiler/orcv2.jl")
225225

226-
include("gradientutils.jl")
227-
226+
import .Enzyme: GradientUtils, call_samefunc_with_inverted_bundles!,
227+
get_width, get_mode, get_runtime_activity,
228+
get_strong_zero, get_shadow_type, get_uncacheable,
229+
erase_with_placeholder, is_constant_value, is_constant_inst,
230+
new_from_original, lookup_value, invert_pointer, debug_from_orig!,
231+
add_reverse_block!, set_reverse_block!, enzyme_context
228232

229233
# Julia function to LLVM stem and arity
230234
const cmplx_known_ops =
@@ -2504,7 +2508,7 @@ function enzyme!(
25042508
convert(API.CDIFFE_TYPE, rt)
25052509
end
25062510

2507-
enzyme_context = EnzymeContext()
2511+
enzyme_context = EnzymeContext(job.world)
25082512
GC.@preserve enzyme_context begin
25092513
LLVM.@dispose logic = Logic(enzyme_context) begin
25102514

src/gradientutils.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ erase_with_placeholder(
6666
orig::LLVM.Instruction,
6767
erase::Bool = true,
6868
) = API.EnzymeGradientUtilsEraseWithPlaceholder(gutils, inst, orig, erase)
69+
6970
is_constant_value(gutils::GradientUtils, val::LLVM.Value) =
7071
API.EnzymeGradientUtilsIsConstantValue(gutils, val) != 0
7172

@@ -96,4 +97,16 @@ end
9697

9798
function set_reverse_block!(gutils::GradientUtils, block::LLVM.BasicBlock)
9899
return LLVM.BasicBlock(API.EnzymeGradientUtilsSetReverseBlock(gutils, block))
99-
end
100+
end
101+
102+
function enzyme_context(gutils::GradientUtils)
103+
ptr = API.EnzymeGradientUtilsGetExternalContext(gutils)
104+
@assert ptr != C_NULL
105+
return unsafe_pointer_to_objref(ptr)::EnzymeContext
106+
end
107+
108+
function enzyme_gutils_context(gutils::API.EnzymeGradientUtilsRef)
109+
ptr = API.EnzymeGradientUtilsGetExternalContext(gutils)
110+
@assert ptr != C_NULL
111+
return unsafe_pointer_to_objref(ptr)::EnzymeContext
112+
end

src/logic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function enzyme_context(logic::Logic)
1818
return logic.ctx::EnzymeContext
1919
end
2020

21-
function enzyme_context(logic::API.EnzymeLogicRef)
21+
function enzyme_logic_context(logic::API.EnzymeLogicRef)
2222
ptr = API.LogicGetExternalContext(logic)
2323
@assert ptr != C_NULL
2424
return unsafe_pointer_to_objref(ptr)::EnzymeContext

0 commit comments

Comments
 (0)