Skip to content

Commit 358d647

Browse files
authored
Precompilation is cool, we should do more of it (#2160)
* Precompilation is cool, we should do more of it * fix * tm stuff * ix attempt * reset * more * ix * reduce * fix
1 parent 3ad827f commit 358d647

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1212
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92"
15+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1516
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1617
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/Enzyme.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,4 +1587,6 @@ Returns true if within autodiff, otherwise false.
15871587
"""
15881588
@inline EnzymeCore.within_autodiff() = false
15891589

1590+
include("precompile.jl")
1591+
15901592
end # module

src/compiler/orcv2.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function define_absolute_symbol(jd, name)
8383
return false
8484
end
8585

86-
function __init__()
86+
function setup_globals()
8787
opt_level = Base.JLOptions().opt_level
8888
if opt_level < 2
8989
optlevel = LLVM.API.LLVMCodeGenLevelNone
@@ -105,11 +105,6 @@ function __init__()
105105
dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix)
106106
LLVM.add!(jd_main, dg)
107107

108-
if Sys.iswindows() && Int === Int64
109-
# TODO can we check isGNU?
110-
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
111-
end
112-
113108
es = ExecutionSession(lljit)
114109
try
115110
lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es)
@@ -120,6 +115,17 @@ function __init__()
120115
jit[] = CompilerInstance(lljit, nothing, nothing)
121116
end
122117

118+
jd_main, lljit
119+
end
120+
121+
function __init__()
122+
jd_main, lljit = setup_globals()
123+
124+
if Sys.iswindows() && Int === Int64
125+
# TODO can we check isGNU?
126+
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
127+
end
128+
123129
hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid}))
124130
for (k, v) in Compiler.JuliaGlobalNameMap
125131
ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k)))

src/precompile.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using PrecompileTools: @setup_workload, @compile_workload
2+
3+
@setup_workload begin
4+
precompile_module = @eval module $(gensym())
5+
f(x) = x^2
6+
end
7+
8+
Compiler.JIT.setup_globals()
9+
10+
@compile_workload begin
11+
Enzyme.autodiff(Reverse, precompile_module.f, Active(2.0))
12+
end
13+
end

0 commit comments

Comments
 (0)