
Adjoint of cholesky is hard-coded for the CPU  #1210


Description

@Red-Portal

Hi,

I've been attempting to differentiate through a Cholesky decomposition, which is common practice in Gaussian processes. The problem is that the current adjoint for cholesky is hard-coded to use the CPU version of trsm!.

See the following minimal working example:

using CUDA
using KernelAbstractions
using CUDAKernels
using LinearAlgebra

import Tullio
import Zygote

function main()
    N = 1024
    D = 16
    X = randn(Float32, D, N)
    y = randn(Float32, N)

    CUDA.allowscalar(true)
    X_dev = CuArray(X)
    y_dev = CuArray(y)
    @time begin
        ∇K = Zygote.gradient(cu(randn(Float32, D+2))) do θ
            ℓα     = θ[1:1]
            ℓϵ     = θ[2]
            logℓ   = θ[3:end]
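            # assemble the GP kernel matrix on the GPU with Tullio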
            Tullio.@tullio K[i,j] := exp(ℓα[1]*2 - (X_dev[k,i] - X_dev[k,j])^2 / exp(2*logℓ[k])) verbose=true
            K_ϵ      = K + cu(exp(ℓϵ)*I)
            K_ϵ_chol = cholesky(K_ϵ)
            α        = K_ϵ_chol \ y_dev
            dot(α, y_dev)
        end
    end
end

main()

output:

ERROR: ArgumentError: cannot take the CPU address of a CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Stacktrace:
  [1] unsafe_convert(#unused#::Type{Ptr{Float32}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ CUDA ~/.julia/packages/CUDA/5jdFl/src/array.jl:315
  [2] trsm!(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Float32, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LinearAlgebra.BLAS /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/blas.jl:1958
  [3] (::Zygote.var"#817#818"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Cholesky{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::NamedTuple{(:uplo, :info, :factors), Tuple{Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:603
  [4] (::Zygote.var"#3217#back#819"{Zygote.var"#817#818"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Cholesky{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}})(Δ::NamedTuple{(:uplo, :info, :factors), Tuple{Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./REPL[33]:17 [inlined]
  [6] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#56#57"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
  [8] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
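
For what it's worth, the failure can be isolated to the trsm! call itself. A rough sketch (the matrix contents here are arbitrary, only the types matter):

using CUDA
using LinearAlgebra

A = CuArray(Matrix{Float32}(I, 4, 4))   # any nonsingular upper-triangular CuArray
B = CUDA.rand(Float32, 4, 4)

# This is what the pullback ends up calling; CPU BLAS needs a host pointer,
# so it throws the ArgumentError shown above:
# LinearAlgebra.BLAS.trsm!('R', 'U', 'T', 'N', 1f0, A, B)

# The CUBLAS wrapper with the same arguments accepts device arrays:
CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', 1f0, A, B)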

A simple fix is to use the following snippet:

@eval Zygote begin
    import CUDA
    @adjoint function cholesky(Σ::CUDA.CuArray; check = true)
        C = cholesky(Σ, check = check)
        C, function(Δ::NamedTuple)
            issuccess(C) || throw(PosDefException(C.info))
            U, Ū = C.U, Δ.factors

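            # work with plain dense CuArrays here; see the note on triangular
            # matrix products below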
            U_tru = triu(U.data)
            Ū_tru = triu(Ū.data)

            Σ̄ = similar(U.data)
            Σ̄ = mul!(Σ̄, Ū_tru, U_tru')
            Σ̄ = copytri!(Σ̄, 'U')
            Σ̄ = ldiv!(U, Σ̄)
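            # call the CUBLAS trsm! instead of the CPU LinearAlgebra.BLAS.trsm!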
            Σ̄ = CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
            Σ̄[diagind(Σ̄)] ./= 2
            return (UpperTriangular(Σ̄),)
        end
    end
end

The two calls to triu are necessary to work around a performance bug in the matrix multiplication between two triangular matrices. I didn't pursue the cause further, but multiplying two triangular matrices on the GPU seems to be roughly 100 times slower than a plain matrix multiplication. Any thoughts on the reason for this?
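
Something along these lines shows the difference (a rough sketch; the sizes are arbitrary and the exact factor will depend on the GPU and the CUDA.jl version). The slow case mirrors the mul!(Σ̄, Ū, U') product in the stock adjoint, the fast case the triu workaround above:

using CUDA
using LinearAlgebra

N   = 4096
A   = CUDA.randn(Float32, N, N)
B   = CUDA.randn(Float32, N, N)
out = similar(A)

# Product with a triangular wrapper, as in the stock adjoint's mul!(Σ̄, Ū, U'):
U = UpperTriangular(A)
CUDA.@sync mul!(out, B, U')            # warm-up
@time CUDA.@sync mul!(out, B, U')

# Same values as a plain dense CuArray, as in the triu workaround; this
# dispatches to the ordinary dense matmul:
U_tru = triu(A)
CUDA.@sync mul!(out, B, U_tru')        # warm-up
@time CUDA.@sync mul!(out, B, U_tru')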
