Skip to content

Commit fb6f959

Browse files
vchuravywsmoses
andauthored
Support Julia 1.11 (#1372)
* Test against v1.11 * WIP: adapt to 1.11 changes * fix constructor * Update interpreter.jl * add cache_token * fixup! add cache_token * Apply suggestions from code review * fixup! Apply suggestions from code review --------- Co-authored-by: William Moses <[email protected]>
1 parent bd60907 commit fb6f959

File tree

3 files changed

+66
-20
lines changed

3 files changed

+66
-20
lines changed

.github/workflows/CI.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
- '1.8'
2727
- '1.9'
2828
- '1.10'
29+
- ~1.11.0-0
2930
- 'nightly'
3031
os:
3132
- ubuntu-20.04
@@ -86,6 +87,11 @@ jobs:
8687
libEnzyme: packaged
8788
version: '1.10'
8889
assertions: true
90+
- os: ubuntu-20.04
91+
arch: x64
92+
libEnzyme: packaged
93+
version: '1.11'
94+
assertions: true
8995
steps:
9096
- uses: actions/checkout@v2
9197
- uses: julia-actions/setup-julia@v1
@@ -170,6 +176,7 @@ jobs:
170176
- '1.8'
171177
- '1.9'
172178
- '1.10'
179+
- ~1.11.0-0
173180
- 'nightly'
174181
os:
175182
- ubuntu-latest

src/compiler.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,8 +2992,27 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams})
29922992
GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme"
29932993

29942994
# provide a specific interpreter to use.
2995+
if VERSION >= v"1.11.0-DEV.1552"
2996+
struct EnzymeCacheToken
2997+
target_type::Type
2998+
always_inline
2999+
method_table::Core.MethodTable
3000+
param_type::Type
3001+
mode::API.CDerivativeMode
3002+
end
3003+
3004+
GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
3005+
EnzymeCacheToken(
3006+
typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job),
3007+
typeof(job.config.params), job.config.params.mode,
3008+
)
3009+
3010+
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
3011+
Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
3012+
else
29953013
GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
29963014
Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
3015+
end
29973016

29983017
include("compiler/passes.jl")
29993018
include("compiler/optimize.jl")

src/compiler/interpreter.jl

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
module Interpreter
22
import Enzyme: API
33
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance
4-
using GPUCompiler: CodeCache, WorldView, @safe_debug
4+
using GPUCompiler: @safe_debug
5+
if VERSION < v"1.11.0-DEV.1552"
6+
using GPUCompiler: CodeCache, WorldView, @safe_debug
7+
end
8+
const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"
9+
510
import ..Enzyme
611
import ..EnzymeRules
712

13+
@static if VERSION v"1.11.0-DEV.1498"
14+
import Core.Compiler: get_inference_world
15+
using Base: get_world_counter
16+
else
17+
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
18+
end
19+
820
struct EnzymeInterpreter <: AbstractInterpreter
9-
global_cache::CodeCache
21+
@static if HAS_INTEGRATED_CACHE
22+
token::Any
23+
else
24+
code_cache::CodeCache
25+
end
1026
method_table::Union{Nothing,Core.MethodTable}
1127

1228
# Cache of inference results for this particular interpreter
@@ -19,34 +35,38 @@ struct EnzymeInterpreter <: AbstractInterpreter
1935
opt_params::OptimizationParams
2036

2137
mode::API.CDerivativeMode
38+
end
2239

23-
function EnzymeInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode)
24-
@assert world <= Base.get_world_counter()
40+
function EnzymeInterpreter(cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode)
41+
@assert world <= Base.get_world_counter()
2542

26-
return new(
27-
cache,
28-
mt,
43+
return EnzymeInterpreter(
44+
cache_or_token,
45+
mt,
2946

30-
# Initially empty cache
31-
Vector{InferenceResult}(),
47+
# Initially empty cache
48+
Vector{InferenceResult}(),
3249

33-
# world age counter
34-
world,
50+
# world age counter
51+
world,
3552

36-
# parameters for inference and optimization
37-
InferenceParams(unoptimize_throw_blocks=false),
38-
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
39-
OptimizationParams(unoptimize_throw_blocks=false),
40-
mode
41-
)
42-
end
53+
# parameters for inference and optimization
54+
InferenceParams(unoptimize_throw_blocks=false),
55+
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
56+
OptimizationParams(unoptimize_throw_blocks=false),
57+
mode
58+
)
4359
end
4460

4561
Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params
4662
Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params
47-
Core.Compiler.get_world_counter(interp::EnzymeInterpreter) = interp.world
63+
get_inference_world(interp::EnzymeInterpreter) = interp.world
4864
Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache
49-
Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.global_cache, interp.world)
65+
@static if HAS_INTEGRATED_CACHE
66+
Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token
67+
else
68+
Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world)
69+
end
5070

5171
# No need to do any locking since we're not putting our results into the runtime cache
5272
Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing

0 commit comments

Comments
 (0)