Skip to content

Commit 6074c79

Browse files
committed
stabilize AD test
1 parent 40dba2e commit 6074c79

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

test/autodiff/ad.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -614,23 +614,23 @@ for V in spacelist
614614
return LinearAlgebra.tr(Str) + LinearAlgebra.norm(Utr * Vᴴtr)
615615
end
616616

617-
trunc = truncrank(round(Int, dim(V_trunc)))
618-
USVᴴ_trunc = svd_trunc(t; trunc)
617+
trunc = truncrank(ceil(Int, dim(V_trunc)))
618+
USVᴴ_trunc = svd_trunc(t; trunc)
619619
g1, = Zygote.gradient(x -> f(x; trunc), t)
620-
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
620+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
621621
@test g1 g2
622622

623623
trunc = truncerror(; atol = last(USVᴴ_trunc))
624-
USVᴴ_trunc = svd_trunc(t; trunc)
624+
USVᴴ_trunc = svd_trunc(t; trunc)
625625
g1, = Zygote.gradient(x -> f(x; trunc), t)
626-
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
626+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
627627
@test g1 g2
628628

629-
tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2]))
629+
tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2]); init = zero(scalartype(USVᴴ_trunc[2])))
630630
trunc = trunctol(; atol = 10 * tol)
631-
USVᴴ_trunc = svd_trunc(t; trunc)
631+
USVᴴ_trunc = svd_trunc(t; trunc)
632632
g1, = Zygote.gradient(x -> f(x; trunc), t)
633-
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
633+
g2, = Zygote.gradient(x -> f(x; trunc = truncspace(space(USVᴴ_trunc[2], 1))), t)
634634
@test g1 g2
635635
end
636636
end

0 commit comments

Comments
 (0)