Skip to content

Commit c52a00f

Browse files
committed
matfun.jl autotangent
1 parent 5b13ffb commit c52a00f

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

test/rulesets/LinearAlgebra/matfun.jl

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,18 @@
66
nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
77

88
A = randn(T, n, n)
9-
ΔA = randn(T, n, n)
109
A *= nrm / opnorm(A, 1)
1110
tols = nrm == 0.1 ? (atol=1e-8, rtol=1e-8) : NamedTuple()
12-
frule_test(LinearAlgebra.exp!, (A, ΔA); tols...)
11+
test_frule(LinearAlgebra.exp!, A; tols...)
1312
end
1413
@testset "imbalanced A" begin
1514
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
16-
ΔA = rand_tangent(A)
17-
frule_test(LinearAlgebra.exp!, (A, ΔA))
15+
test_frule(LinearAlgebra.exp!, A)
1816
end
1917
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
2018
A = Matrix(Hermitian(randn(T, n, n)))
21-
ΔA = randn(T, n, n)
22-
frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA))))
23-
frule_test(LinearAlgebra.exp!, (A, ΔA))
19+
test_frule(LinearAlgebra.exp!, A)
20+
test_frule(LinearAlgebra.exp!, A Matrix(Hermitian(randn(T, n, n))))
2421
end
2522
end
2623

@@ -31,27 +28,24 @@
3128
nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
3229

3330
A = randn(T, n, n)
34-
ΔA = randn(T, n, n)
35-
ΔY = randn(T, n, n)
3631
A *= nrm / opnorm(A, 1)
3732
# rrule is not inferrable, but pullback should be
3833
tols = nrm == 0.1 ? (atol=1e-8, rtol=1e-8) : NamedTuple()
39-
rrule_test(exp, ΔY, (A, ΔA); check_inferred=false, tols...)
34+
test_rrule(exp, A; check_inferred=false, tols...)
4035
Y, back = rrule(exp, A)
41-
@inferred back(ΔY)
36+
@inferred back(rand_tangent(Y))
4237
end
4338
@testset "imbalanced A" begin
4439
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
45-
ΔA = rand_tangent(A)
46-
ΔY = rand_tangent(exp(A))
47-
rrule_test(exp, ΔY, (A, ΔA); check_inferred=false)
40+
test_rrule(exp, A; check_inferred=false)
4841
end
4942
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
5043
A = Matrix(Hermitian(randn(T, n, n)))
51-
ΔA = randn(T, n, n)
52-
ΔY = randn(T, n, n)
53-
rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred=false)
54-
rrule_test(exp, ΔY, (A, ΔA); check_inferred=false)
44+
test_rrule(exp, A; check_inferred=false)
45+
test_rrule(
46+
exp, A;
47+
check_inferred=false, output_tangent=Matrix(Hermitian(randn(T, n, n)))
48+
)
5549
end
5650
end
5751
end

0 commit comments

Comments
 (0)