Skip to content

Commit 7de8fe2

Browse files
committed
fix copy_input signatures in Mooncake tests
1 parent 21d225e commit 7de8fe2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

test/mooncake.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,8 @@ end
519519

520520
left_orth_qr(X) = left_orth(X; alg = :qr)
521521
left_orth_polar(X) = left_orth(X; alg = :polar)
522-
MatrixAlgebraKit.copy_input(left_orth_qr, A) = MatrixAlgebraKit.copy_input(left_orth, A)
523-
MatrixAlgebraKit.copy_input(left_orth_polar, A) = MatrixAlgebraKit.copy_input(left_orth, A)
522+
MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
523+
MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
524524

525525
Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
526526
test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
@@ -530,7 +530,7 @@ end
530530
end
531531

532532
left_null_qr(X) = left_null(X; alg = :qr)
533-
MatrixAlgebraKit.copy_input(left_null_qr, A) = MatrixAlgebraKit.copy_input(left_null, A)
533+
MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
534534
N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
535535
ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
536536
dN = make_mooncake_tangent(ΔN)
@@ -539,8 +539,8 @@ end
539539

540540
right_orth_lq(X) = right_orth(X; alg = :lq)
541541
right_orth_polar(X) = right_orth(X; alg = :polar)
542-
MatrixAlgebraKit.copy_input(right_orth_lq, A) = MatrixAlgebraKit.copy_input(right_orth, A)
543-
MatrixAlgebraKit.copy_input(right_orth_polar, A) = MatrixAlgebraKit.copy_input(right_orth, A)
542+
MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
543+
MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A)
544544
Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
545545
test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
546546

@@ -550,7 +550,7 @@ end
550550
end
551551

552552
right_null_lq(X) = right_null(X; alg = :lq)
553-
MatrixAlgebraKit.copy_input(right_null_lq, A) = MatrixAlgebraKit.copy_input(right_null, A)
553+
MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
554554
Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
555555
ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
556556
dNᴴ = make_mooncake_tangent(ΔNᴴ)

0 commit comments

Comments
 (0)