@@ -593,19 +593,12 @@ for V in spacelist
593593 test_ad_rrule (svd_compact, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol)
594594 test_ad_rrule (svd_compact, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol)
595595
596- # TODO : I'm not sure how to properly test with spaces that might change
597- # with the finite-difference methods, as then the jacobian is ill-defined.
598-
599- trunc = truncrank (max (2 , round (Int, min (dim (domain (t)), dim (codomain (t))) * (3 / 4 ))))
600- USVᴴ_trunc = svd_trunc (t; trunc)
601- ΔUSVᴴ_trunc = (rand_tangent .(Base. front (USVᴴ_trunc))... , zero (last (USVᴴ_trunc)))
602- remove_svdgauge_dependence! (
603- ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], Base. front (USVᴴ_trunc)... ; degeneracy_atol
604- )
605- # test_ad_rrule(svd_trunc, t;
606- # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
607-
608- trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))
596+ # Testing truncation with finitedifferences is RNG-prone since the
597+ # Jacobian changes size if the truncation space changes, causing errors.
598+ # So, first test the fixed space case, then do more limited testing on
599+ # some gradients and compare to the fixed space case
600+ V_trunc = spacetype (t)(c => ceil (Int, min (size (b)... ) / 2 ) for (c, b) in blocks (t))
601+ trunc = truncspace (V_trunc)
609602 USVᴴ_trunc = svd_trunc (t; trunc)
610603 ΔUSVᴴ_trunc = (rand_tangent .(Base. front (USVᴴ_trunc))... , zero (last (USVᴴ_trunc)))
611604 remove_svdgauge_dependence! (
@@ -616,26 +609,30 @@ for V in spacelist
616609 fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol
617610 )
618611
619- # ϵ = norm(*(USVᴴ_trunc...) - t)
620- # trunc = truncerror(; atol=ϵ)
621- # USVᴴ_trunc = svd_trunc(t; trunc)
622- # ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc)
623- # remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...;
624- # degeneracy_atol)
625- # test_ad_rrule(svd_trunc, t;
626- # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
612+ # attempt to construct a loss function that doesn't depend on the gauges
613+ function f (t; trunc)
614+ Utr, Str, Vᴴtr, ϵ = svd_trunc (t; trunc)
615+ return LinearAlgebra. tr (Str) + LinearAlgebra. norm (Utr * Vᴴtr)
616+ end
617+
618+ trunc = truncrank (round (Int, dim (V_trunc)))
619+ USVᴴ_trunc = svd_trunc (t; trunc)
620+ g1, = Zygote. gradient (x -> f (x; trunc), t)
621+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
622+ @test g1 ≈ g2
623+
624+ trunc = truncerror (; atol = last (USVᴴ_trunc))
625+ USVᴴ_trunc = svd_trunc (t; trunc)
626+ g1, = Zygote. gradient (x -> f (x; trunc), t)
627+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
628+ @test g1 ≈ g2
627629
628630 tol = minimum (((c, b),) -> minimum (diagview (b)), blocks (USVᴴ_trunc[2 ]))
629631 trunc = trunctol (; atol = 10 * tol)
630632 USVᴴ_trunc = svd_trunc (t; trunc)
631- ΔUSVᴴ_trunc = (rand_tangent .(Base. front (USVᴴ_trunc))... , zero (last (USVᴴ_trunc)))
632- remove_svdgauge_dependence! (
633- ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], Base. front (USVᴴ_trunc)... ; degeneracy_atol
634- )
635- # test_ad_rrule(
636- # svd_trunc, t;
637- # fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol
638- # )
633+ g1, = Zygote. gradient (x -> f (x; trunc), t)
634+ g2, = Zygote. gradient (x -> f (x; trunc = truncspace (space (USVᴴ_trunc[2 ], 1 ))), t)
635+ @test g1 ≈ g2
639636 end
640637 end
641638
0 commit comments