@@ -11,21 +11,23 @@ include("ad_utils.jl")
1111for f in
1212 (
1313 :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
14- :eig_full, :eig_trunc, :eigh_full, :eigh_trunc, :svd_compact, :svd_trunc,
14+ :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
15+ :svd_compact, :svd_trunc, :svd_vals,
1516 :left_polar, :right_polar,
1617 )
1718 copy_f = Symbol(:copy_, f)
1819 f! = Symbol(f, ' !' )
20+ _hermitian = startswith(string(f), " eigh" )
1921 @eval begin
2022 function $ copy_f(input, alg)
21- if $ f === eigh_full || $ f === eigh_trunc
23+ if $ _hermitian
2224 input = (input + input' ) / 2
2325 end
2426 return $f(input, alg)
2527 end
2628 function ChainRulesCore.rrule(::typeof($copy_f), input, alg)
2729 output = MatrixAlgebraKit.initialize_output($f!, input, alg)
28- if $f === eigh_full || $f === eigh_trunc
30+ if $_hermitian
2931 input = (input + input' ) / 2
3032 else
3133 input = copy(input)
@@ -228,12 +230,13 @@ end
228230 ΔD2 = Diagonal(randn(rng, complex(T), m))
229231 for alg in (LAPACK_Simple(), LAPACK_Expert())
230232 test_rrule(
231- copy_eig_full, A, alg ⊢ NoTangent();
232- output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol
233+ copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
233234 )
234235 test_rrule(
235- copy_eig_full, A, alg ⊢ NoTangent();
236- output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol
236+ copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
237+ )
238+ test_rrule(
239+ copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
237240 )
238241 for r in 1 : 4 : m
239242 truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
284287 config, last ∘ eig_full, A;
285288 output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
286289 )
290+ test_rrule(
291+ config, eig_vals, A;
292+ output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
293+ )
287294end
288295
289296@timedtestset " EIGH AD Rules with eltype $T " for T in (Float64, ComplexF64, Float32)
@@ -304,12 +311,13 @@ end
304311 )
305312 # copy_eigh_full includes a projector onto the Hermitian part of the matrix
306313 test_rrule(
307- copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV),
308- atol = atol, rtol = rtol
314+ copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
309315 )
310316 test_rrule(
311- copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV),
312- atol = atol, rtol = rtol
317+ copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
318+ )
319+ test_rrule(
320+ copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
313321 )
314322 for r in 1:4:m
315323 truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
361369 config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A;
362370 output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
363371 )
372+ test_rrule(
373+ config, eigh_vals ∘ Matrix ∘ Hermitian, A;
374+ output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
375+ )
364376 eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...)
365377 for r in 1:4:m
366378 trunc = truncrank(r; by = real)
404416 copy_svd_compact, A, alg ⊢ NoTangent();
405417 output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol
406418 )
419+ test_rrule(
420+ copy_svd_vals, A, alg ⊢ NoTangent();
421+ output_tangent = diagview(ΔS), atol, rtol
422+ )
407423 for r in 1:4:minmn
408424 truncalg = TruncatedAlgorithm(alg, truncrank(r))
409425 ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
451467 output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol,
452468 rrule_f = rrule_via_ad, check_inferred = false
453469 )
470+ test_rrule(
471+ config, svd_vals, A;
472+ output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
473+ )
454474 for r in 1:4:minmn
455475 trunc = truncrank(r)
456476 ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)
0 commit comments