1+ module Enzyme
2+
3+ using .. GPUCompiler
4+
"""
    EnzymeTarget{T<:AbstractCompilerTarget}

Compiler target that wraps another compiler target, stored in the `target` field.
"""
struct EnzymeTarget{T<:AbstractCompilerTarget} <: AbstractCompilerTarget
    # The wrapped target; its concrete type is captured by the type parameter.
    target::T
end
8+
"""
    EnzymeTarget(; kwargs...)

Build an `EnzymeTarget` around a `GPUCompiler.NativeCompilerTarget` constructed with
`jlruntime = true` followed by the forwarded `kwargs`.

`kwargs` are splatted after `jlruntime`, so a caller-supplied `jlruntime` keyword takes
precedence over the default.
"""
function EnzymeTarget(; kwargs...)
    native = GPUCompiler.NativeCompilerTarget(; jlruntime=true, kwargs...)
    return EnzymeTarget(native)
end
12+
# Queries that take only the target are answered by asking the wrapped target.
for query in (:llvm_triple, :llvm_datalayout, :llvm_machine, :dwarf_version)
    @eval function GPUCompiler.$query(target::EnzymeTarget)
        return GPUCompiler.$query(target.target)
    end
end

# `have_fma` also takes a type argument, which is passed through unchanged.
function GPUCompiler.have_fma(target::EnzymeTarget, T::Type)
    return GPUCompiler.have_fma(target.target, T)
end

# Nesting ignores the existing wrapper and wraps `other` in a new `EnzymeTarget`.
function GPUCompiler.nest_target(::EnzymeTarget, other::AbstractCompilerTarget)
    return EnzymeTarget(other)
end
19+
"""
    AbstractEnzymeCompilerParams

Supertype of the compiler-parameter types defined in this module.
"""
abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end

"""
    EnzymeCompilerParams{P<:AbstractCompilerParams}

Compiler parameters that wrap another set of parameters, stored in the `params` field.
"""
struct EnzymeCompilerParams{P<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
    # The wrapped parameters; their concrete type is captured by the type parameter.
    params::P
end

"""
    PrimalCompilerParams

Field-less parameter type. The zero-argument `EnzymeCompilerParams()` constructor wraps an
instance of it.
"""
struct PrimalCompilerParams <: AbstractEnzymeCompilerParams end
26+
# With no arguments, wrap a fresh `PrimalCompilerParams`.
function EnzymeCompilerParams()
    return EnzymeCompilerParams(PrimalCompilerParams())
end
28+
# Nesting discards the current wrapper and wraps `other` in a new `EnzymeCompilerParams`.
function GPUCompiler.nest_params(::EnzymeCompilerParams, other::AbstractCompilerParams)
    return EnzymeCompilerParams(other)
end
30+
"""
    GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})

Compile `job` by unwrapping its `EnzymeTarget` and `EnzymeCompilerParams`, building a new
job for the wrapped target and parameters, and returning whatever `compile_unhooked` yields
for that inner job.

The inner configuration copies `toplevel` and `always_inline` from the outer one. It fixes
`kernel`, `optimize`, `cleanup`, `only_entry` and `validate` to `false`, and `libraries` to
`true`. The inner job reuses the outer job's `source` and `world`.
"""
function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})
    outer = job.config

    # The type assertions enforce which wrapper types this method expects to unwrap.
    inner_target = (outer.target::EnzymeTarget).target
    inner_params = (outer.params::EnzymeCompilerParams).params

    inner_config = CompilerConfig(
        inner_target,
        inner_params;
        toplevel=outer.toplevel,
        always_inline=outer.always_inline,
        kernel=false,
        libraries=true,
        optimize=false,
        cleanup=false,
        only_entry=false,
        validate=false,
        # ??? entry_abi  (not forwarded; left as an open question)
    )
    inner_job = CompilerJob(job.source, inner_config, job.world)
    result = GPUCompiler.compile_unhooked(output, inner_job)

    # Normally, Enzyme would run here and transform the output of the primal job.
    return result
end
54+
55+ import GPUCompiler: deferred_codegen_jobs
56+ import Core. Compiler as CC
57+
"""
    deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)

Generator behind `deferred_codegen_id`.

`ft` and `tt` must be `Type{...}` wrappers (asserted); the wrapped function type and
argument tuple type are extracted from them. The generator then:

1. looks up the method matching `Tuple{ft, tt.parameters...}` in `world`; if there is none,
   it returns a stub whose body throws a `MethodError`;
2. obtains the `MethodInstance` for the match and copies its `CodeInfo`, emptying the copy;
3. records the world bounds reported by the lookup and the method instance as an edge;
4. builds a `CompilerJob` with a default `EnzymeTarget()` and `EnzymeCompilerParams()` and
   stores it in `GPUCompiler.deferred_codegen_jobs` under a fresh integer id;
5. returns the new `CodeInfo`, whose only statement returns that id.
"""
function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)
    @nospecialize
    @assert CC.isType(ft) && CC.isType(tt)
    # Unwrap `Type{F}` / `Type{TT}` to the types they carry.
    ft = ft.parameters[1]
    tt = tt.parameters[1]

    stub = Core.GeneratedFunctionStub(
        identity, Core.svec(:deferred_codegen_id, :ft, :tt), Core.svec()
    )

    # look up the method match
    method_error = :(throw(MethodError(ft, tt, $world)))
    sig = Tuple{ft,tt.parameters...}
    min_world = Ref{UInt}(typemin(UInt))
    max_world = Ref{UInt}(typemax(UInt))
    match = ccall(:jl_gf_invoke_lookup_worlds, Any,
                  (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
                  sig, #=mt=# nothing, world, min_world, max_world)
    match === nothing && return stub(world, source, method_error)

    # look up the method and code instance
    mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
               (Any, Any, Any), match.method, match.spec_types, match.sparams)
    ci = CC.retrieve_code_info(mi, world)

    # prepare a new code info: keep the copied metadata, drop every statement
    new_ci = copy(ci)
    empty!(new_ci.code)
    empty!(new_ci.codelocs)
    empty!(new_ci.linetable)
    empty!(new_ci.ssaflags)
    new_ci.ssavaluetypes = 0

    # propagate edge metadata
    new_ci.min_world = min_world[]
    new_ci.max_world = max_world[]
    new_ci.edges = Core.MethodInstance[mi]

    # prepare the slots: the self slot plus the two arguments of `deferred_codegen_id`.
    # The self slot uses its conventional name, without any leading whitespace.
    new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
    new_ci.slotflags = fill(0x00, 3)

    # We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
    target = EnzymeTarget()
    params = EnzymeCompilerParams()
    config = CompilerConfig(target, params; kernel=false)
    job = CompilerJob(mi, config, world)

    # NOTE(review): the id is derived from the current number of registered jobs, which
    # assumes entries are never removed and that registration is not concurrent -- confirm.
    id = length(deferred_codegen_jobs) + 1
    deferred_codegen_jobs[id] = job

    # return the deferred_codegen_id
    push!(new_ci.code, CC.ReturnNode(id))
    push!(new_ci.ssaflags, 0x00)
    push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance))
    push!(new_ci.codelocs, 1)
    new_ci.ssavaluetypes += 1

    return new_ci
end
116+
# `deferred_codegen_id(ft, tt)` has no ordinary body. The two meta expressions spliced into
# it mark the function as generated-only and name `deferred_codegen_id_generator` as the
# generator that produces its code.
let generated_only = Expr(:meta, :generated_only),
    generated_by = Expr(:meta, :generated, deferred_codegen_id_generator)

    @eval function deferred_codegen_id(ft, tt)
        $generated_only
        $generated_by
    end
end
121+
"""
    deferred_codegen(f::Type, tt::Type)

Obtain the job id for `(f, tt)` from `deferred_codegen_id` and pass it to the external
`deferred_codegen` symbol through `llvmcall`, returning the resulting `Ptr{Cvoid}`.

NOTE(review): the pointer is presumably resolved by GPUCompiler to the compiled code for the
registered job -- confirm against GPUCompiler's handling of `deferred_codegen`.
"""
@inline function deferred_codegen(f::Type, tt::Type)
    id = deferred_codegen_id(f, tt)
    # The name must begin with `extern ` (no leading whitespace) for `llvmcall` to treat it
    # as an external symbol.
    return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
end
126+
127+ end
0 commit comments