Skip to content

Commit ff3ccc3

Browse files
committed
add mock Enzyme tests
1 parent db6f2d3 commit ff3ccc3

File tree

3 files changed

+161
-0
lines changed

3 files changed

+161
-0
lines changed

test/helpers/enzyme.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
module Enzyme
2+
3+
using ..GPUCompiler
4+
5+
struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget
6+
target::Target
7+
end
8+
9+
function EnzymeTarget(;kwargs...)
10+
EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
11+
end
12+
13+
GPUCompiler.llvm_triple(target::EnzymeTarget) = GPUCompiler.llvm_triple(target.target)
14+
GPUCompiler.llvm_datalayout(target::EnzymeTarget) = GPUCompiler.llvm_datalayout(target.target)
15+
GPUCompiler.llvm_machine(target::EnzymeTarget) = GPUCompiler.llvm_machine(target.target)
16+
GPUCompiler.nest_target(::EnzymeTarget, other::AbstractCompilerTarget) = EnzymeTarget(other)
17+
GPUCompiler.have_fma(target::EnzymeTarget, T::Type) = GPUCompiler.have_fma(target.target, T)
18+
GPUCompiler.dwarf_version(target::EnzymeTarget) = GPUCompiler.dwarf_version(target.target)
19+
20+
abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end
21+
struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
22+
params::Params
23+
end
24+
struct PrimalCompilerParams <: AbstractEnzymeCompilerParams
25+
end
26+
27+
EnzymeCompilerParams() = EnzymeCompilerParams(PrimalCompilerParams())
28+
29+
GPUCompiler.nest_params(::EnzymeCompilerParams, other::AbstractCompilerParams) = EnzymeCompilerParams(other)
30+
31+
function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})
32+
config = job.config
33+
primal_target = (job.config.target::EnzymeTarget).target
34+
primal_params = (job.config.params::EnzymeCompilerParams).params
35+
36+
primal_config = CompilerConfig(
37+
primal_target,
38+
primal_params;
39+
toplevel = config.toplevel,
40+
always_inline = config.always_inline,
41+
kernel = false,
42+
libraries = true,
43+
optimize = false,
44+
cleanup = false,
45+
only_entry = false,
46+
validate = false,
47+
# ??? entry_abi
48+
)
49+
primal_job = CompilerJob(job.source, primal_config, job.world)
50+
return GPUCompiler.compile_unhooked(output, primal_job)
51+
52+
# Normally, Enzyme would run here and transform the output of the primal job.
53+
end
54+
55+
import GPUCompiler: deferred_codegen_jobs
56+
import Core.Compiler as CC
57+
58+
function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)
59+
@nospecialize
60+
@assert CC.isType(ft) && CC.isType(tt)
61+
ft = ft.parameters[1]
62+
tt = tt.parameters[1]
63+
64+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:deferred_codegen_id, :ft, :tt), Core.svec())
65+
66+
# look up the method match
67+
method_error = :(throw(MethodError(ft, tt, $world)))
68+
sig = Tuple{ft, tt.parameters...}
69+
min_world = Ref{UInt}(typemin(UInt))
70+
max_world = Ref{UInt}(typemax(UInt))
71+
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
72+
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
73+
sig, #=mt=# nothing, world, min_world, max_world)
74+
match === nothing && return stub(world, source, method_error)
75+
76+
# look up the method and code instance
77+
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
78+
(Any, Any, Any), match.method, match.spec_types, match.sparams)
79+
ci = CC.retrieve_code_info(mi, world)
80+
81+
# prepare a new code info
82+
new_ci = copy(ci)
83+
empty!(new_ci.code)
84+
empty!(new_ci.codelocs)
85+
empty!(new_ci.linetable)
86+
empty!(new_ci.ssaflags)
87+
new_ci.ssavaluetypes = 0
88+
89+
# propagate edge metadata
90+
new_ci.min_world = min_world[]
91+
new_ci.max_world = max_world[]
92+
new_ci.edges = Core.MethodInstance[mi]
93+
94+
# prepare the slots
95+
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
96+
new_ci.slotflags = UInt8[0x00 for i = 1:3]
97+
98+
# We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
99+
target = EnzymeTarget()
100+
params = EnzymeCompilerParams()
101+
config = CompilerConfig(target, params; kernel=false)
102+
job = CompilerJob(mi, config, world)
103+
104+
id = length(deferred_codegen_jobs) + 1
105+
deferred_codegen_jobs[id] = job
106+
107+
# return the deferred_codegen_id
108+
push!(new_ci.code, CC.ReturnNode(id))
109+
push!(new_ci.ssaflags, 0x00)
110+
push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance))
111+
push!(new_ci.codelocs, 1)
112+
new_ci.ssavaluetypes += 1
113+
114+
return new_ci
115+
end
116+
117+
@eval function deferred_codegen_id(ft, tt)
118+
$(Expr(:meta, :generated_only))
119+
$(Expr(:meta, :generated, deferred_codegen_id_generator))
120+
end
121+
122+
@inline function deferred_codegen(f::Type, tt::Type)
123+
id = deferred_codegen_id(f, tt)
124+
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
125+
end
126+
127+
end

test/native.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,20 @@ end
561561
["jl_invoke", "apply_iterate",
562562
"inttoptr", "apply_type"])
563563
end
564+
565+
@testset "Mock Enzyme" begin
566+
function kernel(a)
567+
a[1] = a[1]^2
568+
return
569+
end
570+
571+
function dkernel(a)
572+
ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Vector{Float64}})
573+
ccall(ptr, Cvoid, (Vector{Float64},), a)
574+
return
575+
end
576+
577+
ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none))
578+
@test !occursin("deferred_codegen", ir)
579+
@test occursin("call void @julia_kernel", ir)
580+
end

test/ptx.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,23 @@ end
121121
end
122122
end
123123

124+
@testset "Mock Enzyme" begin
125+
function kernel(a)
126+
unsafe_store!(a, unsafe_load(a)^2)
127+
return
128+
end
129+
130+
function dkernel(a)
131+
ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
132+
ccall(ptr, Cvoid, (Ptr{Float64},), a)
133+
return
134+
end
135+
136+
ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none))
137+
@test !occursin("deferred_codegen", ir)
138+
@test occursin("call void @julia_kernel", ir)
139+
end
140+
124141
end
125142

126143
############################################################################################

0 commit comments

Comments
 (0)