@vchuravy
Member

@vchuravy vchuravy commented May 30, 2025

The crux of Enzyme GPU support is that we use a @generated deferred_codegen implementation, in which we do not know what the calling environment is. We might be called from the CPU, CUDA, AMDGPU, and so forth.

During CUDA compilation, GPUCompiler then finds the Enzyme compilation job in the deferred_jobs dictionary and asks Enzyme to codegen the adjoint code. To generate the adjoint code, Enzyme must first codegen the primal/original code, and thus must construct a compilation job for CUDA.
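
To make the flow above concrete, here is a minimal, self-contained sketch of the register-then-resolve pattern. It does not use GPUCompiler: `MockJob`, `mock_deferred_jobs`, `register_deferred!`, and `resolve_deferred` are hypothetical names invented for illustration, and targets are reduced to plain symbols. The id scheme (`length + 1`) mirrors the test helper in this PR; swapping in the parent's target on lookup is a simplifying assumption, not the actual mechanism.

```julia
# Self-contained mock: none of these names are GPUCompiler's real API.
struct MockJob
    func::Symbol      # function to compile later
    target::Symbol    # target recorded at registration time
end

# Jobs recorded while the calling environment is still unknown.
const mock_deferred_jobs = Dict{Int, MockJob}()

# Registration: the caller's target is unknown here, so a native default is recorded.
function register_deferred!(func::Symbol)
    id = length(mock_deferred_jobs) + 1
    mock_deferred_jobs[id] = MockJob(func, :native)
    return id
end

# Resolution: the parent compilation knows its own target and looks the job up by id.
# (Assumption: we simply swap in the parent's target.)
function resolve_deferred(id::Int, parent_target::Symbol)
    job = mock_deferred_jobs[id]
    return MockJob(job.func, parent_target)
end

id = register_deferred!(:kernel)
resolved = resolve_deferred(id, :cuda)
println(resolved)
```

The registered job keeps its native default; only the job built at resolution time carries the parent's target.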

Previously, we passed parent_job through so that Enzyme could perform the mode switch.

Here I propose that we instead support nesting both targets and params, so that Enzyme can reuse them correctly instead of guessing.
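
As a rough illustration of what nesting a target could look like, the sketch below uses standalone mock types rather than GPUCompiler itself. `MockTarget`, `MockNativeTarget`, `MockPTXTarget`, `MockEnzymeTarget`, and `mock_nest_target` are all hypothetical names; the wrapper is shaped like the `EnzymeTarget{Target}` test helper, and the argument order (deferred job's target first, parent's target second) follows the call in src/driver.jl. The fallback that returns the inner target unchanged is an assumption about a sensible default.

```julia
# Standalone mock types; not GPUCompiler's real API.
abstract type MockTarget end

struct MockNativeTarget <: MockTarget end

struct MockPTXTarget <: MockTarget
    cap::VersionNumber
end

# Wrapper target, shaped like the EnzymeTarget{Target} helper in the tests.
struct MockEnzymeTarget{T <: MockTarget} <: MockTarget
    target::T
end

# Assumed default: a target that wraps nothing ignores the parent.
mock_nest_target(inner::MockTarget, parent::MockTarget) = inner

# A wrapper target re-wraps the parent's target, so the primal code
# is compiled for the environment that actually called it.
mock_nest_target(::MockEnzymeTarget, parent::MockTarget) = MockEnzymeTarget(parent)

deferred = MockEnzymeTarget(MockNativeTarget())   # guessed when the caller was unknown
nested = mock_nest_target(deferred, MockPTXTarget(v"7.0"))
@show nested
```

Dispatching on the wrapper type keeps the decision with the nested compiler: only a target that opts in by defining its own method is rewritten, and every other target passes through untouched. A params counterpart would presumably follow the same shape.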

I am open to renaming the function, and I will add some tests here later. EnzymeAD/Enzyme.jl#2424 is the other side of this change.

x-ref: #668 (comment)

cc: @wsmoses

@codecov

codecov bot commented May 30, 2025

Codecov Report

Attention: Patch coverage is 60.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 73.28%. Comparing base (8b8c73f) to head (009bbd3).
Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/interface.jl | 0.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #696      +/-   ##
==========================================
+ Coverage   71.63%   73.28%   +1.65%     
==========================================
  Files          24       24              
  Lines        3519     3523       +4     
==========================================
+ Hits         2521     2582      +61     
+ Misses        998      941      -57     

@vchuravy vchuravy marked this pull request as ready for review June 13, 2025 15:43
@vchuravy vchuravy force-pushed the vc/nested_targets branch from 9f205aa to 650b7f4 Compare June 13, 2025 15:43
@github-actions
Contributor

github-actions bot commented Jun 13, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

diff --git a/src/driver.jl b/src/driver.jl
index f610611..e64791b 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -224,7 +224,7 @@ const __llvm_initialized = Ref(false)
                 dyn_entry_fn = get!(jobs, dyn_job) do
                     target = nest_target(dyn_job.config.target, job.config.target)
                     params = nest_params(dyn_job.config.params, job.config.params)
-                    config = CompilerConfig(dyn_job.config; toplevel=false, target, params)
+                    config = CompilerConfig(dyn_job.config; toplevel = false, target, params)
                     dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
                     dyn_entry_fn = LLVM.name(dyn_meta.entry)
                     merge!(compiled, dyn_meta.compiled)
diff --git a/test/helpers/enzyme.jl b/test/helpers/enzyme.jl
index 133e9a5..55547c9 100644
--- a/test/helpers/enzyme.jl
+++ b/test/helpers/enzyme.jl
@@ -2,12 +2,12 @@ module Enzyme
 
 using ..GPUCompiler
 
-struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget
+struct EnzymeTarget{Target <: AbstractCompilerTarget} <: AbstractCompilerTarget
     target::Target
 end
 
-function EnzymeTarget(;kwargs...)
-    EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
+function EnzymeTarget(; kwargs...)
+    return EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
 end
 
 GPUCompiler.llvm_triple(target::EnzymeTarget) = GPUCompiler.llvm_triple(target.target)
@@ -18,7 +18,7 @@ GPUCompiler.have_fma(target::EnzymeTarget, T::Type) = GPUCompiler.have_fma(targe
 GPUCompiler.dwarf_version(target::EnzymeTarget) = GPUCompiler.dwarf_version(target.target)
 
 abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end
-struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
+struct EnzymeCompilerParams{Params <: AbstractCompilerParams} <: AbstractEnzymeCompilerParams
     params::Params
 end
 struct PrimalCompilerParams <: AbstractEnzymeCompilerParams
@@ -68,14 +68,18 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     sig = Tuple{ft, tt.parameters...}
     min_world = Ref{UInt}(typemin(UInt))
     max_world = Ref{UInt}(typemax(UInt))
-    match = ccall(:jl_gf_invoke_lookup_worlds, Any,
-                  (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
-                  sig, #=mt=# nothing, world, min_world, max_world)
+    match = ccall(
+        :jl_gf_invoke_lookup_worlds, Any,
+        (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
+        sig, #=mt=# nothing, world, min_world, max_world
+    )
     match === nothing && return stub(world, source, method_error)
 
     # look up the method and code instance
-    mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
-               (Any, Any, Any), match.method, match.spec_types, match.sparams)
+    mi = ccall(
+        :jl_specializations_get_linfo, Ref{Core.MethodInstance},
+        (Any, Any, Any), match.method, match.spec_types, match.sparams
+    )
     ci = CC.retrieve_code_info(mi, world)
 
     # prepare a new code info
@@ -83,10 +87,10 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     new_ci = copy(ci)
     empty!(new_ci.code)
     @static if isdefined(Core, :DebugInfo)
-      new_ci.debuginfo = Core.DebugInfo(:none)
+        new_ci.debuginfo = Core.DebugInfo(:none)
     else
-      empty!(new_ci.codelocs)
-      resize!(new_ci.linetable, 1)                # see note below
+        empty!(new_ci.codelocs)
+        resize!(new_ci.linetable, 1)                # see note below
     end
     empty!(new_ci.ssaflags)
     new_ci.ssavaluetypes = 0
@@ -99,7 +103,7 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
 
     # prepare the slots
     new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
-    new_ci.slotflags = UInt8[0x00 for i = 1:3]
+    new_ci.slotflags = UInt8[0x00 for i in 1:3]
     @static if isdefined(Core, :DebugInfo)
         new_ci.nargs = 3
     end
@@ -107,7 +111,7 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     # We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
     target = EnzymeTarget()
     params = EnzymeCompilerParams()
-    config = CompilerConfig(target, params; kernel=false)
+    config = CompilerConfig(target, params; kernel = false)
     job = CompilerJob(mi, config, world)
 
     id = length(deferred_codegen_jobs) + 1
@@ -116,9 +120,9 @@ function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::
     # return the deferred_codegen_id
     push!(new_ci.code, CC.ReturnNode(id))
     push!(new_ci.ssaflags, 0x00)
-        @static if isdefined(Core, :DebugInfo)
+    @static if isdefined(Core, :DebugInfo)
     else
-      push!(new_ci.codelocs, 1)   # see note below
+        push!(new_ci.codelocs, 1)   # see note below
     end
     new_ci.ssavaluetypes += 1
 
@@ -137,7 +141,7 @@ end
 
 @inline function deferred_codegen(f::Type, tt::Type)
     id = deferred_codegen_id(f, tt)
-    ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
+    return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
 end
 
-end
\ No newline at end of file
+end
diff --git a/test/native.jl b/test/native.jl
index cba496f..6122e15 100644
--- a/test/native.jl
+++ b/test/native.jl
@@ -659,14 +659,14 @@ end
         a[1] = a[1]^2
         return
     end
-    
+
     function dkernel(a)
         ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Vector{Float64}})
         ccall(ptr, Cvoid, (Vector{Float64},), a)
         return
     end
 
-    ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none))
+    ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo = :none))
     @test !occursin("deferred_codegen", ir)
     @test occursin("call void @julia_kernel", ir)
 end
diff --git a/test/ptx.jl b/test/ptx.jl
index 9e56ee5..c8e3a99 100644
--- a/test/ptx.jl
+++ b/test/ptx.jl
@@ -152,22 +152,22 @@ end
 end
 end
 
-@testset "Mock Enzyme" begin
-    function kernel(a)
-        unsafe_store!(a, unsafe_load(a)^2)
-        return
-    end
-    
-    function dkernel(a)
-        ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
-        ccall(ptr, Cvoid, (Ptr{Float64},), a)
-        return
-    end
+    @testset "Mock Enzyme" begin
+        function kernel(a)
+            unsafe_store!(a, unsafe_load(a)^2)
+            return
+        end
 
-    ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none))
-    @test !occursin("deferred_codegen", ir)
-    @test occursin("call void @julia_", ir)
-end
+        function dkernel(a)
+            ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
+            ccall(ptr, Cvoid, (Ptr{Float64},), a)
+            return
+        end
+
+        ir = sprint(io -> Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo = :none))
+        @test !occursin("deferred_codegen", ir)
+        @test occursin("call void @julia_", ir)
+    end
 
 end
 

@vchuravy vchuravy force-pushed the vc/nested_targets branch from 8a9ab5d to 1662181 Compare June 27, 2025 15:33
@vchuravy vchuravy merged commit afa9599 into master Jun 30, 2025
19 of 22 checks passed
@vchuravy vchuravy deleted the vc/nested_targets branch June 30, 2025 08:29