1+ module Enzyme
2+
3+ using .. GPUCompiler
4+
"""
    EnzymeTarget{Target<:AbstractCompilerTarget}

Compiler target that wraps another compiler target, stored in the `target` field.
"""
struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget
    target::Target
end
8+
"""
    EnzymeTarget(; kwargs...)

Build an `EnzymeTarget` around a `GPUCompiler.NativeCompilerTarget`.

`jlruntime = true` is passed first, followed by the caller's `kwargs`, so a
caller-supplied `jlruntime` keyword takes precedence.
"""
function EnzymeTarget(; kwargs...)
    native = GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...)
    return EnzymeTarget(native)
end
12+
# Target queries are answered by the wrapped target.
function GPUCompiler.llvm_triple(target::EnzymeTarget)
    return GPUCompiler.llvm_triple(target.target)
end

function GPUCompiler.llvm_datalayout(target::EnzymeTarget)
    return GPUCompiler.llvm_datalayout(target.target)
end

function GPUCompiler.llvm_machine(target::EnzymeTarget)
    return GPUCompiler.llvm_machine(target.target)
end

# Nesting discards the current wrapper and wraps `other` instead.
function GPUCompiler.nest_target(::EnzymeTarget, other::AbstractCompilerTarget)
    return EnzymeTarget(other)
end

function GPUCompiler.have_fma(target::EnzymeTarget, T::Type)
    return GPUCompiler.have_fma(target.target, T)
end

function GPUCompiler.dwarf_version(target::EnzymeTarget)
    return GPUCompiler.dwarf_version(target.target)
end
19+
"""
    AbstractEnzymeCompilerParams

Abstract supertype of the compiler-parameter types declared in this module.
"""
abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end

"""
    EnzymeCompilerParams{Params<:AbstractCompilerParams}

Compiler parameters that wrap another set of compiler parameters, stored in the
`params` field.
"""
struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
    params::Params
end

"""
    PrimalCompilerParams

Field-less compiler-parameter type.
"""
struct PrimalCompilerParams <: AbstractEnzymeCompilerParams end
26+
# Zero-argument constructor: wrap a fresh `PrimalCompilerParams`.
function EnzymeCompilerParams()
    return EnzymeCompilerParams(PrimalCompilerParams())
end
28+
# Nesting discards the current wrapper and wraps `other` instead.
function GPUCompiler.nest_params(::EnzymeCompilerParams, other::AbstractCompilerParams)
    return EnzymeCompilerParams(other)
end
30+
# Compile a job whose target is an `EnzymeTarget`.
#
# The wrapped target and wrapped params are unpacked, a new job is built for
# them with optimization, cleanup and validation switched off, and the result of
# compiling that inner job is returned unchanged.
function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})
    outer = job.config
    inner_target = (outer.target::EnzymeTarget).target
    inner_params = (outer.params::EnzymeCompilerParams).params

    # Only `toplevel` and `always_inline` are carried over from the outer
    # configuration; every other option is fixed here.
    inner_config = CompilerConfig(
        inner_target,
        inner_params;
        toplevel = outer.toplevel,
        always_inline = outer.always_inline,
        kernel = false,
        libraries = true,
        optimize = false,
        cleanup = false,
        only_entry = false,
        validate = false,
        # ??? entry_abi
    )
    inner_job = CompilerJob(job.source, inner_config, job.world)

    # Normally, Enzyme would run here and transform the output of the primal job.
    return GPUCompiler.compile_unhooked(output, inner_job)
end
54+
55+ import GPUCompiler: deferred_codegen_jobs
56+ import Core. Compiler as CC
57+
# Generator behind `deferred_codegen_id(ft, tt)`.
#
# Looks up the method instance for the call signature `Tuple{ft, tt.parameters...}`
# in `world`, registers a `CompilerJob` for it in `deferred_codegen_jobs`, and
# returns a `CodeInfo` whose only statement returns the integer id under which
# the job was stored.
#
# Arguments:
# - `world`:  world age used for the method lookup and for the job.
# - `source`: forwarded to the generated-function stub on the error path.
# - `self`:   unused.
# - `ft`, `tt`: `Type{...}` wrappers around the function type and argument tuple type.
#
# When no method matches, the returned body throws a `MethodError` instead.
function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)
    @nospecialize
    @assert CC.isType(ft) && CC.isType(tt)
    # unwrap Type{F} / Type{TT}
    ft = ft.parameters[1]
    tt = tt.parameters[1]

    stub = Core.GeneratedFunctionStub(identity, Core.svec(:deferred_codegen_id, :ft, :tt),
                                      Core.svec())

    # look up the method match
    method_error = :(throw(MethodError(ft, tt, $world)))
    sig = Tuple{ft, tt.parameters...}
    min_world = Ref{UInt}(typemin(UInt))
    max_world = Ref{UInt}(typemax(UInt))
    match = ccall(:jl_gf_invoke_lookup_worlds, Any,
                  (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
                  sig, #=mt=# nothing, world, min_world, max_world)
    match === nothing && return stub(world, source, method_error)

    # look up the method and code instance
    mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
               (Any, Any, Any), match.method, match.spec_types, match.sparams)
    ci = CC.retrieve_code_info(mi, world)

    # prepare a new code info
    # TODO: Can we create a new CI instead of copying a "wrong" one?
    new_ci = copy(ci)
    empty!(new_ci.code)
    @static if isdefined(Core, :DebugInfo)
        new_ci.debuginfo = Core.DebugInfo(:none)
    else
        empty!(new_ci.codelocs)
        resize!(new_ci.linetable, 1)    # see note below
    end
    empty!(new_ci.ssaflags)
    new_ci.ssavaluetypes = 0

    # propagate edge metadata
    # The lower bound found by the lookup (`min_world[]`) is not used; the
    # validity range starts at the requested world.
    new_ci.min_world = world
    new_ci.max_world = max_world[]
    new_ci.edges = Core.MethodInstance[mi]

    # prepare the slots
    new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
    new_ci.slotflags = fill(0x00, 3)
    @static if isdefined(Core, :DebugInfo)
        new_ci.nargs = 3
    end

    # We don't know the caller's target so EnzymeTarget uses the default
    # NativeCompilerTarget.
    target = EnzymeTarget()
    params = EnzymeCompilerParams()
    config = CompilerConfig(target, params; kernel = false)
    job = CompilerJob(mi, config, world)

    # NOTE(review): ids are derived from the current number of registered jobs,
    # which presumes entries are never removed from `deferred_codegen_jobs`.
    id = length(deferred_codegen_jobs) + 1
    deferred_codegen_jobs[id] = job

    # return the deferred_codegen_id
    push!(new_ci.code, CC.ReturnNode(id))
    push!(new_ci.ssaflags, 0x00)
    @static if !isdefined(Core, :DebugInfo)
        push!(new_ci.codelocs, 1)       # see note below
    end
    new_ci.ssavaluetypes += 1

    # NOTE: on Julia versions without Core.DebugInfo we keep the first entry of the
    #       original linetable and use it as the location of the emitted statement.
    #       We can't not have a codeloc (using 0 causes corruption of the back trace),
    #       and reusing the target function's info has as advantage that we see the
    #       name of the kernel in the backtraces.

    return new_ci
end
132+
# Generated-only function: it has no ordinary body, and its code is produced by
# `deferred_codegen_id_generator`.
@eval function deferred_codegen_id(ft, tt)
    $(Expr(:meta, :generated_only))
    $(Expr(:meta, :generated, deferred_codegen_id_generator))
end
137+
# Obtain the id from `deferred_codegen_id(f, tt)` and pass it to the external
# `deferred_codegen` symbol, returning the resulting `Ptr{Cvoid}`.
#
# The symbol string must begin with `extern ` (no leading whitespace) for this
# `llvmcall` form of `ccall` to refer to an external function.
@inline function deferred_codegen(f::Type, tt::Type)
    id = deferred_codegen_id(f, tt)
    return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
end
142+
143+ end
0 commit comments