Skip to content

Commit d2ed268

Browse files
committed
fix ad tests
1 parent 3e47462 commit d2ed268

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

test/autodiff/ad.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ function remove_qrgauge_dependence!(ΔQ, t, Q)
9090
end
9191
return ΔQ
9292
end
93-
9493
function remove_lqgauge_dependence!(ΔQ, t, Q)
9594
for (c, b) in blocks(ΔQ)
9695
m, n = size(block(t, c))
@@ -103,7 +102,7 @@ function remove_lqgauge_dependence!(ΔQ, t, Q)
103102
return ΔQ
104103
end
105104
function remove_eiggauge_dependence!(
106-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
105+
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
107106
)
108107
gaugepart = V' * ΔV
109108
for (c, b) in blocks(gaugepart)
@@ -119,9 +118,9 @@ function remove_eiggauge_dependence!(
119118
return ΔV
120119
end
121120
function remove_eighgauge_dependence!(
122-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
121+
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
123122
)
124-
gaugepart = V' * ΔV
123+
gaugepart = project_antihermitian!(V' * ΔV)
125124
gaugepart = (gaugepart - gaugepart') / 2
126125
for (c, b) in blocks(gaugepart)
127126
Dc = diagview(block(D, c))
@@ -136,10 +135,9 @@ function remove_eighgauge_dependence!(
136135
return ΔV
137136
end
138137
function remove_svdgauge_dependence!(
139-
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S)
138+
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S)
140139
)
141-
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
142-
gaugepart = (gaugepart - gaugepart') / 2
140+
gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ')
143141
for (c, b) in blocks(gaugepart)
144142
Sd = diagview(block(S, c))
145143
# for some reason this fails only on tests, and I cannot reproduce it in an
@@ -153,8 +151,6 @@ function remove_svdgauge_dependence!(
153151
return ΔU, ΔVᴴ
154152
end
155153

156-
project_hermitian(A) = (A + A') / 2
157-
158154
# Tests
159155
# -----
160156

@@ -601,7 +597,7 @@ for V in spacelist
601597
USVᴴ_trunc = svd_trunc(t; trunc)
602598
ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc)
603599
remove_svdgauge_dependence!(
604-
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol
600+
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
605601
)
606602
# test_ad_rrule(svd_trunc, t;
607603
# fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
@@ -610,7 +606,7 @@ for V in spacelist
610606
USVᴴ_trunc = svd_trunc(t; trunc)
611607
ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc)
612608
remove_svdgauge_dependence!(
613-
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol
609+
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
614610
)
615611
test_ad_rrule(
616612
svd_trunc, t;
@@ -631,7 +627,7 @@ for V in spacelist
631627
USVᴴ_trunc = svd_trunc(t; trunc)
632628
ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc)
633629
remove_svdgauge_dependence!(
634-
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; degeneracy_atol
630+
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
635631
)
636632
test_ad_rrule(
637633
svd_trunc, t;

0 commit comments

Comments
 (0)