@@ -614,23 +614,23 @@ for V in spacelist
614614 return LinearAlgebra. tr (Str) + LinearAlgebra. norm (Utr * Vᴴtr)
615615 end
616616
617- trunc = truncrank (round (Int, dim (V_trunc)))
618- USVᴴ_trunc = svd_trunc (t; trunc)
617+ trunc = truncrank (ceil (Int, dim (V_trunc)))
618+ USVᴴ_trunc′ = svd_trunc (t; trunc)
619619 g1, = Zygote. gradient (x -> f (x; trunc), t)
620- g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
620+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc′ [2 ], 1 ))), t)
621621 @test g1 ≈ g2
622622
623623 trunc = truncerror (; atol = last (USVᴴ_trunc))
624- USVᴴ_trunc = svd_trunc (t; trunc)
624+ USVᴴ_trunc′ = svd_trunc (t; trunc)
625625 g1, = Zygote. gradient (x -> f (x; trunc), t)
626- g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
626+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc′ [2 ], 1 ))), t)
627627 @test g1 ≈ g2
628628
629- tol = minimum (((c, b),) -> minimum (diagview (b)), blocks (USVᴴ_trunc[2 ]))
629+ tol = minimum (((c, b),) -> minimum (diagview (b)), blocks (USVᴴ_trunc[2 ]); init = zero ( scalartype (USVᴴ_trunc[ 2 ])) )
630630 trunc = trunctol (; atol = 10 * tol)
631- USVᴴ_trunc = svd_trunc (t; trunc)
631+ USVᴴ_trunc′ = svd_trunc (t; trunc)
632632 g1, = Zygote. gradient (x -> f (x; trunc), t)
633- g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
633+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc′ [2 ], 1 ))), t)
634634 @test g1 ≈ g2
635635 end
636636 end
0 commit comments