Skip to content

Commit 0023450

Browse files
committed
WIP: symmetric.jl autotangent
1 parent bd46c0a commit 0023450

File tree

1 file changed

+19
-27
lines changed

1 file changed

+19
-27
lines changed

test/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,39 @@
44
T in (Float64, ComplexF64),
55
uplo in (:U, :L)
66

7-
N = 3
87
@testset "frule" begin
9-
x = randn(T, N, N)
10-
Δx = randn(T, N, N)
11-
# can't use frule_test here because it doesn't yet ignore nothing tangents
12-
Ω = SymHerm(x, uplo)
13-
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
14-
@test Ω_ad == Ω
15-
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
16-
@test ∂Ω_ad ∂Ω_fd
8+
test_frule(SymHerm, rand(T, 3, 3), uplo nothing)
179
end
1810
@testset "rrule" begin
1911
# on old versions of julia this combination doesn't infer but we don't care as
2012
# it infers fine on modern versions.
2113
check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
2214

23-
x = randn(T, N, N)
24-
∂x = randn(T, N, N)
25-
ΔΩ = randn(T, N, N)
2615
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
27-
rrule_test(
28-
SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing);
16+
x = randn(T, 3, 3)
17+
ΔΩ = MT(randn(T, 3, 3))
18+
test_rrule(
19+
SymHerm, x, uplo nothing;
20+
output_tangent = ΔΩ,
2921
# type stability here critically relies on uplo being constant propagated,
3022
# so we need to test this more carefully below
3123
check_inferred=false,
3224
)
3325
if check_inferred
34-
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
26+
@inferred (function (SymHerm, x, ΔΩ, ::Val)
3527
return rrule(SymHerm, x, uplo)[2](ΔΩ)
36-
end)(SymHerm, x, MT(ΔΩ), Val(uplo))
28+
end)(SymHerm, x, ΔΩ, Val(uplo))
3729
end
3830
end
3931
@testset "back(::Diagonal)" begin
40-
rrule_test(
32+
x = randn(T, 3, 3)
33+
ΔΩ = Diagonal(randn(T, 3, 3))
34+
test_rrule(
4135
SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
4236
check_inferred=false,
4337
)
4438
if check_inferred
45-
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
39+
@inferred (function (SymHerm, x, ΔΩ, ::Val)
4640
return rrule(SymHerm, x, uplo)[2](ΔΩ)
4741
end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo))
4842
end
@@ -53,15 +47,13 @@
5347
SymHerm in (Symmetric, Hermitian),
5448
T in (Float64, ComplexF64),
5549
uplo in (:U, :L)
56-
57-
N = 3
58-
x = SymHerm(randn(T, N, N), uplo)
59-
Δx = randn(T, N, N)
60-
∂x = SymHerm(randn(T, N, N), uplo)
61-
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
62-
frule_test(f, (x, Δx))
63-
frule_test(f, (x, SymHerm(Δx, uplo)))
64-
rrule_test(f, ΔΩ, (x, ∂x))
50+
x = SymHerm(randn(T, 3, 3), uplo)
51+
Δx = randn(T, 3, 3)
52+
∂x = SymHerm(randn(T, 3, 3), uplo)
53+
ΔΩ = f(SymHerm(randn(T, 3, 3), uplo))
54+
test_frule(f, (x, Δx))
55+
test_frule(f, (x, SymHerm(Δx, uplo)))
56+
test_rrule(f, ΔΩ, (x, ∂x))
6557
end
6658

6759
# symmetric/hermitian eigendecomposition follows the sign convention

0 commit comments

Comments
 (0)