Issue when differentiating a CUDA kernel (only without KA) #2947

@yolhan83

Hello. In this thread, https://discourse.julialang.org/t/understanding-and-optimizing-enzyme-jl-reverse-ad-on-cuda/133357/4, it was highlighted that some kernels need a shadow for an argument that shouldn't require one.

PS: With a larger N, this also seems to randomly corrupt Julia at first compilation and cause crashes afterwards (with both versions, though this may be unrelated). If we first run it with a small N and only then with a large N, the problem disappears.

I made a better MWE:

using CUDA,Enzyme

function foo_ker!(res,x,theta)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    @inbounds if i <= length(res)
        res[i] = x[i] + theta[i]
    end
    return
end

function foo!(res,x,theta)
    threads = 256
    blocks = cld(length(res),threads)
    @cuda threads=threads blocks=blocks foo_ker!(res,x,theta)
    CUDA.synchronize()
    return
end

function get_grad(res,x,theta)
    dres = Enzyme.make_zero(res) .+ 1
    dtheta = Enzyme.make_zero(theta)
    Enzyme.autodiff(Reverse,foo!,DuplicatedNoNeed(res,dres),Const(x),Duplicated(theta,dtheta))
    dtheta
end

function get_grad2(res,x,theta)
    dres = Enzyme.make_zero(res) .+ 1
    dtheta = Enzyme.make_zero(theta)
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(Reverse,foo!,DuplicatedNoNeed(res,dres),DuplicatedNoNeed(x,dx),Duplicated(theta,dtheta))
    dtheta
end

N = 100
res= CUDA.zeros(N);
x = CUDA.rand(N);
theta = CUDA.rand(N);

foo!(res,x,theta)

get_grad(res,x,theta) # fails
get_grad2(res,x,theta) # works

Error:

ERROR: type Const has no field dval
Stacktrace:
  [1] getproperty
    @ .\Base.jl:49 [inlined]
  [2] augmented_primal
    @ C:\Users\yolha\.julia\packages\CUDA\UurkZ\ext\EnzymeCoreExt.jl:99 [inlined]
  [3] map
    @ .\tuple.jl:357 [inlined]
  [4] macro expansion
    @ C:\Users\yolha\.julia\packages\CUDA\UurkZ\src\compiler\execution.jl:110 [inlined]
  [5] foo!
    @ c:\Users\yolha\Desktop\juju_tests\port_ml\main.jl:14 [inlined]
  [6] diffejulia_foo__81658wrap
    @ c:\Users\yolha\Desktop\juju_tests\port_ml\main.jl:0
  [7] macro expansion
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\compiler.jl:5883 [inlined]
  [8] enzyme_call
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\compiler.jl:5417 [inlined]
  [9] CombinedAdjointThunk
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\compiler.jl:5303 [inlined]
 [10] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\Enzyme.jl:521 [inlined]
 [11] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\Enzyme.jl:562 [inlined]
 [12] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\jbt7B\src\Enzyme.jl:534 [inlined]
 [13] get_grad(res::CuArray{…}, x::CuArray{…}, theta::CuArray{…})
    @ Main c:\Users\yolha\Desktop\juju_tests\port_ml\main.jl:22
 [14] top-level scope
    @ c:\Users\yolha\Desktop\juju_tests\port_ml\main.jl:41

pinging @wsmoses

Versions:

julia> versioninfo()
Julia Version 1.11.7
Commit f2b3dbda30 (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 20 default, 0 interactive, 10 GC (on 20 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1
julia> CUDA.versioninfo()
CUDA toolchain: 
- runtime 12.9, artifact installation
- driver 577.0.0 for 12.9
- compiler 12.9

CUDA libraries:
- CUBLAS: 12.9.1
- CURAND: 10.3.10
- CUFFT: 11.4.1
- CUSOLVER: 11.7.5
- CUSPARSE: 12.5.10
- CUPTI: 2025.2.1 (API 12.9.1)
- NVML: 12.0.0+577.0

Julia packages:
- CUDA: 5.9.2
- CUDA_Driver_jll: 13.0.2+0
- CUDA_Compiler_jll: 0.3.0+0
- CUDA_Runtime_jll: 0.19.2+0

Toolchain:
- Julia: 1.11.7
- LLVM: 16.0.6

1 device:
  0: NVIDIA GeForce RTX 4060 Laptop GPU (sm_89, 7.025 GiB / 7.996 GiB available)
  [052768ef] CUDA v5.9.2
  [7da242da] Enzyme v0.13.89

Labels: bug (Something isn't working)