|
654 | 654 | g(X) = cholesky(X * X' + I) |
655 | 655 | @test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈ |
656 | 656 | Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1] |
657 | | - @test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,)) |
658 | 657 |
|
659 | 658 | # https://github.com/FluxML/Zygote.jl/issues/932 |
660 | 659 | @test gradcheck(rand(5, 5), rand(5)) do A, x |
|
820 | 819 | @test back′(C̄)[1] isa Diagonal |
821 | 820 | @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1]) |
822 | 821 | end |
| 822 | + @testset "cholesky - Hermitian{Complex}" begin |
| 823 | + rng, N = MersenneTwister(123456), 3 |
| 824 | + A = randn(rng, Complex{Float64}, N, N) |
| 825 | + H = Hermitian(A * A' + I) |
| 826 | + Hmat = Matrix(H) |
| 827 | + y, back = Zygote.pullback(cholesky, Hmat) |
| 828 | + y′, back′ = Zygote.pullback(cholesky, H) |
| 829 | + C̄ = (factors=randn(rng, N, N),) |
| 830 | + @test only(back′(C̄)) isa Hermitian |
| 831 | + # gradtest does not support complex gradients, even though the pullback exists |
| 832 | + d = only(back(C̄)) |
| 833 | + d′ = only(back′(C̄)) |
| 834 | + @test (d + d')/2 ≈ d′ |
| 835 | + end |
| 836 | + @testset "cholesky - Hermitian{Real}" begin |
| 837 | + rng, N = MersenneTwister(123456), 3 |
| 838 | + A = randn(rng, N, N) |
| 839 | + H = Hermitian(A * A' + I) |
| 840 | + Hmat = Matrix(H) |
| 841 | + y, back = Zygote.pullback(cholesky, Hmat) |
| 842 | + y′, back′ = Zygote.pullback(cholesky, H) |
| 843 | + C̄ = (factors=randn(rng, N, N),) |
| 844 | + @test back′(C̄)[1] isa Hermitian |
| 845 | + @test gradtest(B->cholesky(Hermitian(B)).U, Hmat) |
| 846 | + @test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat) |
| 847 | + end |
823 | 848 | end |
824 | 849 |
|
825 | 850 | @testset "lyap" begin |
|
0 commit comments