11module Interpreter
22import Enzyme: API
33using Core. Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance
4- using GPUCompiler: CodeCache, WorldView, @safe_debug
4+ using GPUCompiler: @safe_debug
5+ if VERSION < v " 1.11.0-DEV.1552"
6+ using GPUCompiler: CodeCache, WorldView, @safe_debug
7+ end
8+ const HAS_INTEGRATED_CACHE = VERSION >= v " 1.11.0-DEV.1552"
9+
510import .. Enzyme
611import .. EnzymeRules
712
13+ @static if VERSION ≥ v " 1.11.0-DEV.1498"
14+ import Core. Compiler: get_inference_world
15+ using Base: get_world_counter
16+ else
17+ import Core. Compiler: get_world_counter, get_world_counter as get_inference_world
18+ end
19+
820struct EnzymeInterpreter <: AbstractInterpreter
9- global_cache:: CodeCache
21+ @static if HAS_INTEGRATED_CACHE
22+ token:: Any
23+ else
24+ code_cache:: CodeCache
25+ end
1026 method_table:: Union{Nothing,Core.MethodTable}
1127
1228 # Cache of inference results for this particular interpreter
@@ -19,34 +35,38 @@ struct EnzymeInterpreter <: AbstractInterpreter
1935 opt_params:: OptimizationParams
2036
2137 mode:: API.CDerivativeMode
38+ end
2239
23- function EnzymeInterpreter (cache :: CodeCache , mt:: Union{Nothing,Core.MethodTable} , world:: UInt , mode:: API.CDerivativeMode )
24- @assert world <= Base. get_world_counter ()
40+ function EnzymeInterpreter (cache_or_token , mt:: Union{Nothing,Core.MethodTable} , world:: UInt , mode:: API.CDerivativeMode )
41+ @assert world <= Base. get_world_counter ()
2542
26- return new (
27- cache ,
28- mt,
43+ return EnzymeInterpreter (
44+ cache_or_token ,
45+ mt,
2946
30- # Initially empty cache
31- Vector {InferenceResult} (),
47+ # Initially empty cache
48+ Vector {InferenceResult} (),
3249
33- # world age counter
34- world,
50+ # world age counter
51+ world,
3552
36- # parameters for inference and optimization
37- InferenceParams (unoptimize_throw_blocks= false ),
38- VERSION >= v " 1.8.0-DEV.486" ? OptimizationParams () :
39- OptimizationParams (unoptimize_throw_blocks= false ),
40- mode
41- )
42- end
53+ # parameters for inference and optimization
54+ InferenceParams (unoptimize_throw_blocks= false ),
55+ VERSION >= v " 1.8.0-DEV.486" ? OptimizationParams () :
56+ OptimizationParams (unoptimize_throw_blocks= false ),
57+ mode
58+ )
4359end
4460
4561Core. Compiler. InferenceParams (interp:: EnzymeInterpreter ) = interp. inf_params
4662Core. Compiler. OptimizationParams (interp:: EnzymeInterpreter ) = interp. opt_params
47- Core . Compiler . get_world_counter (interp:: EnzymeInterpreter ) = interp. world
63+ get_inference_world (interp:: EnzymeInterpreter ) = interp. world
4864Core. Compiler. get_inference_cache (interp:: EnzymeInterpreter ) = interp. local_cache
49- Core. Compiler. code_cache (interp:: EnzymeInterpreter ) = WorldView (interp. global_cache, interp. world)
65+ @static if HAS_INTEGRATED_CACHE
66+ Core. Compiler. cache_owner (interp:: EnzymeInterpreter ) = interp. token
67+ else
68+ Core. Compiler. code_cache (interp:: EnzymeInterpreter ) = WorldView (interp. code_cache, interp. world)
69+ end
5070
5171# No need to do any locking since we're not putting our results into the runtime cache
5272Core. Compiler. lock_mi_inference (interp:: EnzymeInterpreter , mi:: MethodInstance ) = nothing
0 commit comments