|
519 | 519 |
|
520 | 520 | left_orth_qr(X) = left_orth(X; alg = :qr) |
521 | 521 | left_orth_polar(X) = left_orth(X; alg = :polar) |
522 | | - MatrixAlgebraKit.copy_input(left_orth_qr, A) = MatrixAlgebraKit.copy_input(left_orth, A) |
523 | | - MatrixAlgebraKit.copy_input(left_orth_polar, A) = MatrixAlgebraKit.copy_input(left_orth, A) |
| 522 | + MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) |
| 523 | + MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) |
524 | 524 |
|
525 | 525 | Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) |
526 | 526 | test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) |
|
530 | 530 | end |
531 | 531 |
|
532 | 532 | left_null_qr(X) = left_null(X; alg = :qr) |
533 | | - MatrixAlgebraKit.copy_input(left_null_qr, A) = MatrixAlgebraKit.copy_input(left_null, A) |
| 533 | + MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) |
534 | 534 | N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) |
535 | 535 | ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) |
536 | 536 | dN = make_mooncake_tangent(ΔN) |
|
539 | 539 |
|
540 | 540 | right_orth_lq(X) = right_orth(X; alg = :lq) |
541 | 541 | right_orth_polar(X) = right_orth(X; alg = :polar) |
542 | | - MatrixAlgebraKit.copy_input(right_orth_lq, A) = MatrixAlgebraKit.copy_input(right_orth, A) |
543 | | - MatrixAlgebraKit.copy_input(right_orth_polar, A) = MatrixAlgebraKit.copy_input(right_orth, A) |
| 542 | + MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) |
| 543 | + MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) |
544 | 544 | Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) |
545 | 545 | test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) |
546 | 546 |
|
|
550 | 550 | end |
551 | 551 |
|
552 | 552 | right_null_lq(X) = right_null(X; alg = :lq) |
553 | | - MatrixAlgebraKit.copy_input(right_null_lq, A) = MatrixAlgebraKit.copy_input(right_null, A) |
| 553 | + MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) |
554 | 554 | Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] |
555 | 555 | ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] |
556 | 556 | dNᴴ = make_mooncake_tangent(ΔNᴴ) |
|
0 commit comments