Skip to content

Commit 4bdb660

Browse files
committed
try and add back some AD tests
1 parent 192406a commit 4bdb660

File tree

1 file changed

+26
-29
lines changed

1 file changed

+26
-29
lines changed

test/autodiff/ad.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -593,19 +593,12 @@ for V in spacelist
593593
test_ad_rrule(svd_compact, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol)
594594
test_ad_rrule(svd_compact, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol)
595595

596-
# TODO: I'm not sure how to properly test with spaces that might change
597-
# with the finite-difference methods, as then the jacobian is ill-defined.
598-
599-
trunc = truncrank(max(2, round(Int, min(dim(domain(t)), dim(codomain(t))) * (3 / 4))))
600-
USVᴴ_trunc = svd_trunc(t; trunc)
601-
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
602-
remove_svdgauge_dependence!(
603-
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
604-
)
605-
# test_ad_rrule(svd_trunc, t;
606-
# fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
607-
608-
trunc = truncspace(space(USVᴴ_trunc[2], 1))
596+
# Testing truncation with finite differences is RNG-prone since the
597+
# Jacobian changes size if the truncation space changes, causing errors.
598+
# So, first test the fixed space case, then do more limited testing on
599+
# some gradients and compare to the fixed space case
600+
V_trunc = spacetype(t)(c => ceil(Int, min(size(b)...) / 2) for (c, b) in blocks(t))
601+
trunc = truncspace(V_trunc)
609602
USVᴴ_trunc = svd_trunc(t; trunc)
610603
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
611604
remove_svdgauge_dependence!(
@@ -616,26 +609,30 @@ for V in spacelist
616609
fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol
617610
)
618611

619-
# ϵ = norm(*(USVᴴ_trunc...) - t)
620-
# trunc = truncerror(; atol=ϵ)
621-
# USVᴴ_trunc = svd_trunc(t; trunc)
622-
# ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc)
623-
# remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...;
624-
# degeneracy_atol)
625-
# test_ad_rrule(svd_trunc, t;
626-
# fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
612+
# attempt to construct a loss function that doesn't depend on the gauges
613+
function f(t; trunc)
614+
Utr, Str, Vᴴtr, ϵ = svd_trunc(t; trunc)
615+
return LinearAlgebra.tr(Str) + LinearAlgebra.norm(Utr * Vᴴtr)
616+
end
617+
618+
trunc = truncrank(round(Int, dim(V_trunc)))
619+
USVᴴ_trunc = svd_trunc(t; trunc)
620+
g1, = Zygote.gradient(x -> f(x; trunc), t)
621+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
622+
@test g1 ≈ g2
623+
624+
trunc = truncerror(; atol = last(USVᴴ_trunc))
625+
USVᴴ_trunc = svd_trunc(t; trunc)
626+
g1, = Zygote.gradient(x -> f(x; trunc), t)
627+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
628+
@test g1 ≈ g2
627629

628630
tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2]))
629631
trunc = trunctol(; atol = 10 * tol)
630632
USVᴴ_trunc = svd_trunc(t; trunc)
631-
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
632-
remove_svdgauge_dependence!(
633-
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
634-
)
635-
# test_ad_rrule(
636-
# svd_trunc, t;
637-
# fkwargs = (; trunc), output_tangent = ΔUSVᴴ_trunc, atol, rtol
638-
# )
633+
g1, = Zygote.gradient(x -> f(x; trunc), t)
634+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
635+
@test g1 ≈ g2
639636
end
640637
end
641638

0 commit comments

Comments
 (0)