diff --git a/Project.toml b/Project.toml index e47535efbd..4bd13a772f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -37,6 +38,7 @@ GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" ObjectFile = "0.4" +PrecompileTools = "1.2" Preferences = "1.4" SpecialFunctions = "1, 2" StaticArrays = "1" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9ff56bdd81..12f7495907 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1475,4 +1475,13 @@ macro import_rrule(args...) return _import_rrule(args...) end +# using PrecompileTools +# Crashes on 1.11 +# @setup_workload let +# @compile_workload begin +# autodiff(ReverseMode{false,InlineABI,false}(), ()->nothing, Const) +# autodiff(ForwardMode{InlineABI}(), ()->nothing, Const) +# end +# end + end # module diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index e1652c5895..cc778e7d88 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -68,6 +68,20 @@ else Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) end +const CC = Core.Compiler +@static if HAS_INTEGRATED_CACHE + function CC.CodeInstance(interp::EnzymeInterpreter, result::CC.InferenceResult, + valid_worlds::CC.WorldRange) + ci = @invoke CC.CodeInstance(interp::CC.AbstractInterpreter, result::CC.InferenceResult, + valid_worlds::CC.WorldRange) + + # FIXME: Enzyme embeds global pointers and other fun things directly + # So forbid the caching of the results. + ci.relocatability = 0x0 + return ci + end +end + # No need to do any locking since we're not putting our results into the runtime cache Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing Core.Compiler.unlock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing diff --git a/test/Project.toml b/test/Project.toml index 5c8286d1af..fbdc4b6038 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/test/precompile.jl b/test/precompile.jl new file mode 100644 index 0000000000..98dfc95b1f --- /dev/null +++ b/test/precompile.jl @@ -0,0 +1,62 @@ +using Test + +function precompile_test_harness(@nospecialize(f), testset::String) + @testset "$testset" begin + precompile_test_harness(f, true) + end +end +function precompile_test_harness(@nospecialize(f), separate::Bool) + load_path = mktempdir() + load_cache_path = separate ? mktempdir() : load_path + try + pushfirst!(LOAD_PATH, load_path) + pushfirst!(DEPOT_PATH, load_cache_path) + f(load_path) + finally + try + rm(load_path, force=true, recursive=true) + catch err + @show err + end + if separate + try + rm(load_cache_path, force=true, recursive=true) + catch err + @show err + end + end + filter!((≠)(load_path), LOAD_PATH) + separate && filter!((≠)(load_cache_path), DEPOT_PATH) + end + nothing +end + +precompile_test_harness("Inference caching") do load_path + write(joinpath(load_path, "InferenceCaching.jl"), :(module InferenceCaching + using Enzyme + using PrecompileTools + + function mul(x, y) + return x * y + end + + @setup_workload begin + @compile_workload begin + autodiff(ReverseMode{false,InlineABI,false}(), mul, Active, Active(1.0), Active(2.0)) + # Non-Inline mode uses `@generated` functions and poisons the caller + # autodiff(Reverse, mul, Active, Active(1.0), Active(2.0)) + # autodiff(Forward, mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + end + end + end) |> string) + + Base.compilecache(Base.PkgId("InferenceCaching")) + @eval let + using InferenceCaching + using Enzyme + + @test autodiff(ReverseMode{false,InlineABI,false}(), InferenceCaching.mul, Active, Active(1.0), Active(2.0)) == ((2.0, 1.0),) + # autodiff(Reverse, InferenceCaching.mul, Active, Active(1.0), Active(2.0)) + # autodiff(Forward, InferenceCaching.mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + end +end \ No newline at end of file