Skip to content

Commit 5ed33f1

Browse files
authored
Merge pull request #374 from JuliaDiff/mz/structured
structured.jl autotangent
2 parents 24318b0 + 469e8ad commit 5ed33f1

File tree

1 file changed

+22
-46
lines changed

1 file changed

+22
-46
lines changed

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 22 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,23 @@
22
@testset "/ and \\ on Square Matrixes" begin
33
@testset "//, $T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
44
RHS = T(randn(T == Diagonal ? 10 : (10, 10)))
5-
Y = randn(5, 10)
6-
= randn(size(/(Y, RHS))...)
7-
rrule_test(/, Ȳ, (Y, randn(size(Y))), (RHS, randn(size(RHS))))
5+
test_rrule(/, randn(5, 10), RHS)
86
end
97

108
@testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
119
LHS = T(randn(T == Diagonal ? 10 : (10, 10)))
12-
y = randn(10)
13-
= randn(size(\(LHS, y))...)
14-
rrule_test(\, ȳ, (LHS, randn(size(LHS))), (y, randn(10)))
15-
Y = randn(10, 10)
16-
= randn(10, 10)
17-
rrule_test(\, Ȳ, (LHS, randn(size(LHS))), (Y, randn(size(Y))))
10+
test_rrule(\, LHS, randn(10))
11+
test_rrule(\, LHS, randn(10, 10))
1812
end
1913
end
2014

2115
@testset "Diagonal" begin
2216
N = 3
23-
rrule_test(Diagonal, randn(N, N), (randn(N), randn(N)))
24-
D = Diagonal(randn( N))
25-
rrule_test(Diagonal, D, (randn(N), randn(N)))
17+
test_rrule(Diagonal, randn(N); output_tangent=randn(N, N))
18+
D = Diagonal(randn(N))
19+
test_rrule(Diagonal, randn(N); output_tangent=D)
2620
# Concrete type instead of UnionAll
27-
rrule_test(typeof(D), D, (randn(N), randn(N)))
21+
test_rrule(typeof(D), randn(N); output_tangent=D)
2822

2923
# TODO: replace this with a `rrule_test` once we have that working
3024
# see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24
@@ -35,30 +29,20 @@
3529
end
3630
@testset "dot(x, ::Diagonal, y)" begin
3731
N = 4
38-
x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N)
39-
x̄, d̄, ȳ = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N)
40-
D = Diagonal(d)
41-
= Diagonal(d̄)
42-
rrule_test(dot, rand(ComplexF64), (x, x̄), (D, D̄), (y, ȳ))
32+
test_rrule(dot, randn(ComplexF64, N), Diagonal(randn(ComplexF64, N)), randn(ComplexF64, N))
4333
end
4434
@testset "::Diagonal * ::AbstractVector" begin
4535
N = 3
46-
rrule_test(
47-
*,
48-
randn(N),
49-
(Diagonal(randn(N)), Diagonal(randn(N))),
50-
(randn(N), randn(N)),
51-
)
36+
test_rrule(*, Diagonal(randn(N)), randn(N))
5237
end
5338
@testset "diag" begin
5439
N = 7
55-
rrule_test(diag, randn(N), (randn(N, N), randn(N, N)))
56-
rrule_test(diag, randn(N), (Diagonal(randn(N)), randn(N, N)))
57-
rrule_test(diag, randn(N), (randn(N, N), Diagonal(randn(N))))
58-
rrule_test(diag, randn(N), (Diagonal(randn(N)), Diagonal(randn(N))))
40+
test_rrule(diag, randn(N, N))
41+
test_rrule(diag, Diagonal(randn(N)))
42+
test_rrule(diag, randn(N, N) Diagonal(randn(N)))
43+
test_rrule(diag, Diagonal(randn(N)) Diagonal(randn(N)))
5944
VERSION v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
60-
M = N - abs(k)
61-
rrule_test(diag, randn(M), (randn(N, N), randn(N, N)), (k, nothing))
45+
test_rrule(diag, randn(N, N), k nothing)
6246
end
6347
end
6448
@testset "diagm" begin
@@ -112,25 +96,23 @@
11296
m = 3
11397
@testset "$f(::Matrix{$T})" begin
11498
A = randn(T, n, m)
115-
= randn(T, n, m)
11699
Y = f(A)
117100
Ȳ_mat = randn(T, m, n)
118101
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))
119102

120-
rrule_test(f, Ȳ_mat, (A, Ā))
103+
test_rrule(f, A; output_tangent=Ȳ_mat)
121104

122105
_, pb = rrule(f, A)
123106
@test pb(Ȳ_mat) == pb(Ȳ_composite)
124107
end
125108

126109
@testset "$f(::Vector{$T})" begin
127110
a = randn(T, n)
128-
= randn(T, n)
129111
y = f(a)
130112
ȳ_mat = randn(T, 1, n)
131113
ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))
132114

133-
rrule_test(f, ȳ_mat, (a, ā))
115+
test_rrule(f, a; output_tangent=ȳ_mat)
134116

135117
_, pb = rrule(f, a)
136118
@test pb(ȳ_mat) == pb(ȳ_composite)
@@ -139,30 +121,24 @@
139121
@testset "$f(::Adjoint{$T, Vector{$T})" begin
140122
a = randn(T, n)'
141123
= randn(T, n)'
142-
y = f(a)
143-
= randn(T, n)
144-
145-
rrule_test(f, ȳ, (a, ā))
124+
test_rrule(f, a ā; output_tangent=randn(T, n))
146125
end
147126

148127
@testset "$f(::Transpose{$T, Vector{$T})" begin
149128
a = transpose(randn(T, n))
150129
= transpose(randn(T, n))
151-
y = f(a)
152-
= randn(T, n)
153-
154-
rrule_test(f, ȳ, (a, ā))
130+
test_rrule(f, a ā; output_tangent=randn(T, n))
155131
end
156132
end
157133
@testset "$T" for T in (UpperTriangular, LowerTriangular)
158134
n = 5
159-
rrule_test(T, T(randn(n, n)), (randn(n, n), randn(n, n)))
135+
test_rrule(T, randn(n, n); output_tangent=T(randn(n, n)))
160136
end
161137
@testset "$Op" for Op in (triu, tril)
162138
n = 7
163-
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)))
139+
test_rrule(Op, randn(n, n))
164140
@testset "k=$k" for k in -2:2
165-
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing))
141+
test_rrule(Op, randn(n, n), k nothing)
166142
end
167143
end
168144

@@ -173,7 +149,7 @@
173149
# rand (not randn) so det will be postive, so logdet will be defined
174150
X = S(3*rand(T, (n, n)) .+ 1)
175151
X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
176-
rrule_test(op, rand(T), (X, X̄_acc))
152+
test_rrule(op, X X̄_acc)
177153
end
178154
@testset "return type" begin
179155
X = S(3*rand(6, 6) .+ 1)

0 commit comments

Comments
 (0)