|
287 | 287 | config, last ∘ eig_full, A; |
288 | 288 | output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false |
289 | 289 | ) |
| 290 | + test_rrule( |
| 291 | + config, eig_vals, A; |
| 292 | + output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false |
| 293 | + ) |
290 | 294 | end |
291 | 295 |
|
292 | 296 | @timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) |
|
365 | 369 | config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; |
366 | 370 | output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false |
367 | 371 | ) |
| 372 | + test_rrule( |
| 373 | + config, eigh_vals ∘ Matrix ∘ Hermitian, A; |
| 374 | + output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false |
| 375 | + ) |
368 | 376 | eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) |
369 | 377 | for r in 1:4:m |
370 | 378 | trunc = truncrank(r; by = real) |
|
459 | 467 | output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, |
460 | 468 | rrule_f = rrule_via_ad, check_inferred = false |
461 | 469 | ) |
| 470 | + test_rrule( |
| 471 | + config, svd_vals, A; |
| 472 | + output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false |
| 473 | + ) |
462 | 474 | for r in 1:4:minmn |
463 | 475 | trunc = truncrank(r) |
464 | 476 | ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) |
|
0 commit comments