Skip to content

Commit 815e8dd

Browse files
committed
fix test
1 parent 2abc33e commit 815e8dd

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

test/gradcheck.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,17 +819,31 @@ end
819819
@test back′(C̄)[1] isa Diagonal
820820
@test diag(back′(C̄)[1]) diag(back(C̄)[1])
821821
end
822-
@testset "cholesky - Hermitian" begin
822+
@testset "cholesky - Hermitian{Complex}" begin
823823
rng, N = MersenneTwister(123456), 3
824824
A = randn(rng, Complex{Float64}, N, N)
825825
H = Hermitian(A * A' + I)
826826
Hmat = Matrix(H)
827827
y, back = Zygote.pullback(cholesky, Hmat)
828828
y′, back′ = Zygote.pullback(cholesky, H)
829829
= (factors=randn(rng, N, N),)
830+
@test only(back′(C̄)) isa Hermitian
831+
# gradtest does not support complex gradients, even though the pullback exists
832+
d = only(back(C̄))
833+
d′ = only(back′(C̄))
834+
@test (d + d')/2 d′
835+
end
836+
@testset "cholesky - Hermitian{Real}" begin
837+
rng, N = MersenneTwister(123456), 3
838+
A = randn(rng, N, N)
839+
H = Hermitian(A * A' + I)
840+
Hmat = Matrix(H)
841+
y, back = Zygote.pullback(cholesky, Hmat)
842+
y′, back′ = Zygote.pullback(cholesky, H)
843+
= (factors=randn(rng, N, N),)
830844
@test back′(C̄)[1] isa Hermitian
831-
@test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I)
832-
@test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I)
845+
@test gradtest(B->cholesky(Hermitian(B)).U, Hmat)
846+
@test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat)
833847
end
834848
end
835849

0 commit comments

Comments
 (0)