Skip to content

Commit 5e8ed01

Browse files
committed
fix copy_input signatures in Mooncake tests
1 parent 21d225e commit 5e8ed01

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

test/mooncake.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,20 @@ end
504504
end
505505
end
506506

507+
left_orth_qr(X) = left_orth(X; alg = :qr)
508+
left_orth_polar(X) = left_orth(X; alg = :polar)
509+
left_null_qr(X) = left_null(X; alg = :qr)
510+
right_orth_lq(X) = right_orth(X; alg = :lq)
511+
right_orth_polar(X) = right_orth(X; alg = :polar)
512+
right_null_lq(X) = right_null(X; alg = :lq)
513+
514+
MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
515+
MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
516+
MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
517+
MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
518+
MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A)
519+
MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
520+
507521
@timedtestset "Orth and null with eltype $T" for T in ETs
508522
rng = StableRNG(12345)
509523
m = 19
@@ -517,30 +531,19 @@ end
517531
Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
518532
test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
519533

520-
left_orth_qr(X) = left_orth(X; alg = :qr)
521-
left_orth_polar(X) = left_orth(X; alg = :polar)
522-
MatrixAlgebraKit.copy_input(left_orth_qr, A) = MatrixAlgebraKit.copy_input(left_orth, A)
523-
MatrixAlgebraKit.copy_input(left_orth_polar, A) = MatrixAlgebraKit.copy_input(left_orth, A)
524-
525534
Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
526535
test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
527536
if m >= n
528537
Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
529538
test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
530539
end
531540

532-
left_null_qr(X) = left_null(X; alg = :qr)
533-
MatrixAlgebraKit.copy_input(left_null_qr, A) = MatrixAlgebraKit.copy_input(left_null, A)
534541
N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
535542
ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
536543
dN = make_mooncake_tangent(ΔN)
537544
Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN)
538545
test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN)
539546

540-
right_orth_lq(X) = right_orth(X; alg = :lq)
541-
right_orth_polar(X) = right_orth(X; alg = :polar)
542-
MatrixAlgebraKit.copy_input(right_orth_lq, A) = MatrixAlgebraKit.copy_input(right_orth, A)
543-
MatrixAlgebraKit.copy_input(right_orth_polar, A) = MatrixAlgebraKit.copy_input(right_orth, A)
544547
Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
545548
test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
546549

@@ -549,8 +552,6 @@ end
549552
test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
550553
end
551554

552-
right_null_lq(X) = right_null(X; alg = :lq)
553-
MatrixAlgebraKit.copy_input(right_null_lq, A) = MatrixAlgebraKit.copy_input(right_null, A)
554555
Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
555556
ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
556557
dNᴴ = make_mooncake_tangent(ΔNᴴ)

0 commit comments

Comments
 (0)