Skip to content

Commit aaa5a36

Browse files
authored
Add interface for stashing context in Logic (#2624)
1 parent ed2b72d commit aaa5a36

File tree

5 files changed

+36
-4
lines changed

5 files changed

+36
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ CEnum = "0.4, 0.5"
4545
ChainRulesCore = "1"
4646
DynamicPPL = "0.35, 0.36, 0.37"
4747
EnzymeCore = "0.8.14"
48-
Enzyme_jll = "0.0.201"
48+
Enzyme_jll = "0.0.202"
4949
GPUArraysCore = "0.1.6, 0.2"
5050
GPUCompiler = "1.6.2"
5151
LLVM = "9.1"

src/Enzyme.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ Base.convert(::Type{API.CDerivativeMode}, ::ForwardMode) = API.DEM_ForwardMode
132132

133133
function guess_activity end
134134

135+
mutable struct EnzymeContext
136+
end
137+
135138
include("logic.jl")
136139
include("analyses/type.jl")
137140
include("typetree.jl")

src/api.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,14 @@ function FreeLogic(logic)
996996
ccall((:FreeEnzymeLogic, libEnzyme), Cvoid, (EnzymeLogicRef,), logic)
997997
end
998998

999+
function LogicSetExternalContext(logic, ctx)
1000+
ccall((:EnzymeLogicSetExternalContext, libEnzyme), Cvoid, (EnzymeLogicRef, Ptr{Cvoid}), logic, ctx)
1001+
end
1002+
1003+
function LogicGetExternalContext(logic)
1004+
ccall((:EnzymeLogicGetExternalContext, libEnzyme), Ptr{Cvoid}, (EnzymeLogicRef,), logic)
1005+
end
1006+
9991007
function EnzymeExtractReturnInfo(ret, data, existed)
10001008
@assert length(data) == length(existed)
10011009
ccall(

src/compiler.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Enzyme:
1313
guess_activity,
1414
eltype,
1515
API,
16+
EnzymeContext,
1617
TypeTree,
1718
typetree,
1819
TypeTreeTable,
@@ -2497,7 +2498,10 @@ function enzyme!(
24972498
convert(API.CDIFFE_TYPE, rt)
24982499
end
24992500

2500-
logic = Logic()
2501+
enzyme_context = EnzymeContext()
2502+
GC.@preserve enzyme_context begin
2503+
LLVM.@dispose logic = Logic(enzyme_context) begin
2504+
25012505
TA = TypeAnalysis(logic)
25022506

25032507
retTT = if !isa(actualRetType, Union) &&
@@ -2753,7 +2757,10 @@ function enzyme!(
27532757
if DumpPostEnzyme[]
27542758
API.EnzymeDumpModuleRef(mod.ref)
27552759
end
2760+
27562761
return adjointf, augmented_primalf, TapeType
2762+
end # @dispose logic
2763+
end # GC.preserve enzyme_context
27572764
end
27582765

27592766
function get_subprogram(f::LLVM.Function)

src/logic.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,28 @@ import LLVM: refcheck
22

33
LLVM.@checked struct Logic
44
ref::API.EnzymeLogicRef
5-
function Logic()
5+
ctx::EnzymeContext
6+
function Logic(ctx::EnzymeContext)
67
ref = API.CreateLogic()
7-
new(ref)
8+
GC.@preserve ctx begin
9+
API.LogicSetExternalContext(ref, Base.pointer_from_objref(ctx))
10+
return new(ref, ctx)
11+
end
812
end
913
end
1014
Base.unsafe_convert(::Type{API.EnzymeLogicRef}, logic::Logic) = logic.ref
1115
LLVM.dispose(logic::Logic) = API.FreeLogic(logic)
1216

17+
function enzyme_context(logic::Logic)
18+
return logic.ctx::EnzymeContext
19+
end
20+
21+
function enzyme_context(logic::API.EnzymeLogicRef)
22+
ptr = API.LogicGetExternalContext(logic)
23+
@assert ptr != C_NULL
24+
return unsafe_pointer_to_objref(ptr)::EnzymeContext
25+
end
26+
1327
# typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/,
1428
# CTypeTree * /*args*/, size_t /*numArgs*/,
1529
# LLVMValueRef)=T

0 commit comments

Comments
 (0)