@@ -90,7 +90,6 @@ function remove_qrgauge_dependence!(ΔQ, t, Q)
9090 end
9191 return ΔQ
9292end
93-
9493function remove_lqgauge_dependence! (ΔQ, t, Q)
9594 for (c, b) in blocks (ΔQ)
9695 m, n = size (block (t, c))
@@ -103,7 +102,7 @@ function remove_lqgauge_dependence!(ΔQ, t, Q)
103102 return ΔQ
104103end
105104function remove_eiggauge_dependence! (
106- ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_gauge_atol (D)
105+ ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (D)
107106 )
108107 gaugepart = V' * ΔV
109108 for (c, b) in blocks (gaugepart)
@@ -119,9 +118,9 @@ function remove_eiggauge_dependence!(
119118 return ΔV
120119end
121120function remove_eighgauge_dependence! (
122- ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_gauge_atol (D)
121+ ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (D)
123122 )
124- gaugepart = V' * ΔV
123+ gaugepart = project_antihermitian! ( V' * ΔV)
125124 gaugepart = (gaugepart - gaugepart' ) / 2
126125 for (c, b) in blocks (gaugepart)
127126 Dc = diagview (block (D, c))
@@ -136,10 +135,9 @@ function remove_eighgauge_dependence!(
136135 return ΔV
137136end
138137function remove_svdgauge_dependence! (
139- ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit. default_pullback_gauge_atol (S)
138+ ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (S)
140139 )
141- gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
142- gaugepart = (gaugepart - gaugepart' ) / 2
140+ gaugepart = project_antihermitian! (U' * ΔU + Vᴴ * ΔVᴴ' )
143141 for (c, b) in blocks (gaugepart)
144142 Sd = diagview (block (S, c))
145143 # for some reason this fails only on tests, and I cannot reproduce it in an
@@ -153,8 +151,6 @@ function remove_svdgauge_dependence!(
153151 return ΔU, ΔVᴴ
154152end
155153
156- project_hermitian (A) = (A + A' ) / 2
157-
158154# Tests
159155# -----
160156
@@ -601,7 +597,7 @@ for V in spacelist
601597 USVᴴ_trunc = svd_trunc (t; trunc)
602598 ΔUSVᴴ_trunc = rand_tangent .(USVᴴ_trunc)
603599 remove_svdgauge_dependence! (
604- ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], USVᴴ_trunc... ; degeneracy_atol
600+ ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], Base . front ( USVᴴ_trunc) ... ; degeneracy_atol
605601 )
606602 # test_ad_rrule(svd_trunc, t;
607603 # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol)
@@ -610,7 +606,7 @@ for V in spacelist
610606 USVᴴ_trunc = svd_trunc (t; trunc)
611607 ΔUSVᴴ_trunc = rand_tangent .(USVᴴ_trunc)
612608 remove_svdgauge_dependence! (
613- ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], USVᴴ_trunc... ; degeneracy_atol
609+ ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], Base . front ( USVᴴ_trunc) ... ; degeneracy_atol
614610 )
615611 test_ad_rrule (
616612 svd_trunc, t;
@@ -631,7 +627,7 @@ for V in spacelist
631627 USVᴴ_trunc = svd_trunc (t; trunc)
632628 ΔUSVᴴ_trunc = rand_tangent .(USVᴴ_trunc)
633629 remove_svdgauge_dependence! (
634- ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], USVᴴ_trunc... ; degeneracy_atol
630+ ΔUSVᴴ_trunc[1 ], ΔUSVᴴ_trunc[3 ], Base . front ( USVᴴ_trunc) ... ; degeneracy_atol
635631 )
636632 test_ad_rrule (
637633 svd_trunc, t;
0 commit comments