From ee744bcfe1b5ff9d82dd0699378cc225700ea223 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 8 Jun 2022 19:47:32 +0100 Subject: [PATCH 1/3] add specialization for diff'ing through `Cholesky` with CuArrays --- src/lib/array.jl | 21 ++++++++++++++++++++- test/cuda.jl | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index f492af9e6..3862c8058 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -582,7 +582,7 @@ end C = cholesky(Σ, check = check) return C, function(Δ::NamedTuple) issuccess(C) || throw(PosDefException(C.info)) - U, Ū = C.U, Δ.factors + U, L, Ū = C.U, C.L, Δ.factors Σ̄ = similar(U.data) Σ̄ = mul!(Σ̄, Ū, U') Σ̄ = copytri!(Σ̄, 'U') @@ -593,6 +593,25 @@ end end end +@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin + Zygote.@adjoint function cholesky(Σ::Union{CUDA.CuMatrix, Symmetric{<:Real, <:CUDA.CuMatrix}}; check = true) + C = cholesky(Σ, check = check) + return C, function(Δ::NamedTuple) + issuccess(C) || throw(PosDefException(C.info)) + println(@__LINE__) + U, L, Ū = C.U, C.L, Δ.factors + Σ̄ = similar(U.data) + Σ̄ = mul!(Σ̄, Ū, U') + Σ̄ = copytri!(Σ̄, 'U') + Σ̄ = ldiv!(U, Σ̄) + Σ̄ = CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) + Σ̄[diagind(Σ̄)] ./= 2 + @info("", typeof(Σ̄), typeof(U), typeof(C)) + return (UpperTriangular(Σ̄),) + end + end +end + @adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix) X = lyap(A, C) return X, function (X̄) diff --git a/test/cuda.jl b/test/cuda.jl index 5cb1c8cdc..623cf1347 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -140,3 +140,36 @@ end @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end +@testset "gpu cholesky" begin + rng, M, P, Q = MersenneTwister(123456), 13, 10, 9 + X, y = randn(rng, P, P), randn(rng, P) + dX, dy = cu(X), cu(y) + + f(X, y) = begin + C = cholesky(X*X' + I) + return sum(C \ y) + logdet(C) + end + + @test gradient(dy -> 
f(dX, dy), dy)[1] isa CUDA.CuArray{Float32} + @test gradient(dX -> f(dX, dy), dX)[1] isa CUDA.CuArray{Float32} + + ∇f_dev = Array(gradient(dy -> f(dX, dy), dy)[1]) + ∇f_cpu = gradient(y -> f(X, y), y)[1] + @test ∇f_dev ≈ ∇f_cpu + + ∇f_dev = Array(gradient(dX -> f(dX, dy), dX)[1]) + ∇f_cpu = gradient(X -> f(X, y), X)[1] + @test ∇f_dev ≈ ∇f_cpu + + @test_throws PosDefException Zygote.pullback(X -> cholesky(dX, check = false), X)[2]((factors=dX,)) + + # https://github.com/FluxML/Zygote.jl/issues/932 + # Symmetric is currently not supported by CUDA.jl + # g(X, y) = begin + # C = cholesky(Symmetric(X*X' + I)) + # return sum(C \ y) + logdet(C) + # end + # ∇g_dev = gradient(dy -> g(dX, dy), dy)[1] + # ∇g_cpu = gradient(y -> g(X, y), y)[1] + # @test ∇g_dev ≈ ∇g_cpu +end From 2e73d840d120cb95fc9693a9772b84dc2c97068f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 8 Jun 2022 20:09:45 +0100 Subject: [PATCH 2/3] fix errors in GPU Cholesky test --- test/cuda.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 623cf1347..4a1753b75 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,7 +1,7 @@ using CUDA using Zygote: Grads using LinearAlgebra -using Random: randn! 
+using Random: randn!, MersenneTwister CUDA.allowscalar(false) # Test GPU movement inside the call to `gradient` @@ -140,14 +140,14 @@ end @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end -@testset "gpu cholesky" begin +@testset "GPU cholesky" begin rng, M, P, Q = MersenneTwister(123456), 13, 10, 9 X, y = randn(rng, P, P), randn(rng, P) dX, dy = cu(X), cu(y) f(X, y) = begin C = cholesky(X*X' + I) - return sum(C \ y) + logdet(C) + return sum(C \ y) end @test gradient(dy -> f(dX, dy), dy)[1] isa CUDA.CuArray{Float32} @@ -162,14 +162,4 @@ end @test ∇f_dev ≈ ∇f_cpu @test_throws PosDefException Zygote.pullback(X -> cholesky(dX, check = false), X)[2]((factors=dX,)) - - # https://github.com/FluxML/Zygote.jl/issues/932 - # Symmetric is currently not supported by CUDA.jl - # g(X, y) = begin - # C = cholesky(Symmetric(X*X' + I)) - # return sum(C \ y) + logdet(C) - # end - # ∇g_dev = gradient(dy -> g(dX, dy), dy)[1] - # ∇g_cpu = gradient(y -> g(X, y), y)[1] - # @test ∇g_dev ≈ ∇g_cpu end From 5d0626c456bcef86e0b46ece6c5810e6cca58468 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 8 Jun 2022 20:13:24 +0100 Subject: [PATCH 3/3] remove debugging code in GPU Cholesky adjoint --- src/lib/array.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 3862c8058..74eefc4e6 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -598,7 +598,6 @@ end C = cholesky(Σ, check = check) return C, function(Δ::NamedTuple) issuccess(C) || throw(PosDefException(C.info)) - println(@__LINE__) U, L, Ū = C.U, C.L, Δ.factors Σ̄ = similar(U.data) Σ̄ = mul!(Σ̄, Ū, U') @@ -606,7 +605,6 @@ end Σ̄ = ldiv!(U, Σ̄) Σ̄ = CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) Σ̄[diagind(Σ̄)] ./= 2 - @info("", typeof(Σ̄), typeof(U), typeof(C)) return (UpperTriangular(Σ̄),) end end