Skip to content

Commit da86f12

Browse files
author
Miha Zgubic
committed
fix tests
1 parent 6cbe36f commit da86f12

File tree

1 file changed

+20
-32
lines changed

1 file changed

+20
-32
lines changed

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77

88
@testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
99
LHS = T(randn(T == Diagonal ? 10 : (10, 10)))
10-
test_rrule(\, LHS, y = randn(10))
11-
test_rrule(\, LHS, y = randn(10, 10))
10+
test_rrule(\, LHS, randn(10))
11+
test_rrule(\, LHS, randn(10, 10))
1212
end
1313
end
1414

1515
@testset "Diagonal" begin
1616
N = 3
17-
test_rrule(Diagonal, randn(N))
18-
D = Diagonal(randn( N))
17+
test_rrule(Diagonal, randn(N); output_tangent=randn(N, N))
18+
D = Diagonal(randn(N))
1919
test_rrule(Diagonal, randn(N); output_tangent=D)
2020
# Concrete type instead of UnionAll
2121
test_rrule(typeof(D), randn(N); output_tangent=D)
@@ -29,33 +29,23 @@
2929
end
3030
@testset "dot(x, ::Diagonal, y)" begin
3131
N = 4
32-
x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N)
33-
x̄, d̄, ȳ = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N)
34-
D = Diagonal(d)
35-
= Diagonal(d̄)
36-
rrule_test(dot, rand(ComplexF64), (x, x̄), (D, D̄), (y, ȳ))
32+
test_rrule(dot, randn(ComplexF64, N), Diagonal(randn(ComplexF64, N)), randn(ComplexF64, N))
3733
end
3834
@testset "::Diagonal * ::AbstractVector" begin
3935
N = 3
40-
rrule_test(
41-
*,
42-
randn(N),
43-
(Diagonal(randn(N)), Diagonal(randn(N))),
44-
(randn(N), randn(N)),
45-
)
36+
test_rrule(*, Diagonal(randn(N)), randn(N))
4637
end
4738
@testset "diag" begin
4839
N = 7
49-
rrule_test(diag, randn(N), (randn(N, N), randn(N, N)))
50-
rrule_test(diag, randn(N), (Diagonal(randn(N)), randn(N, N)))
51-
rrule_test(diag, randn(N), (randn(N, N), Diagonal(randn(N))))
52-
rrule_test(diag, randn(N), (Diagonal(randn(N)), Diagonal(randn(N))))
40+
test_rrule(diag, randn(N, N))
41+
test_rrule(diag, Diagonal(randn(N)))
42+
test_rrule(diag, randn(N, N) Diagonal(randn(N)))
43+
test_rrule(diag, Diagonal(randn(N)) Diagonal(randn(N)))
5344
VERSION v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
54-
M = N - abs(k)
55-
rrule_test(diag, randn(M), (randn(N, N), randn(N, N)), (k, nothing))
45+
test_rrule(diag, randn(N, N), k nothing)
5646
end
5747
end
58-
@testset "diagm" begin
48+
@testset "diagm" begin # TODO review testset
5949
@testset "without size" begin
6050
M, N = 7, 9
6151
s = (8, 8)
@@ -106,25 +96,23 @@
10696
m = 3
10797
@testset "$f(::Matrix{$T})" begin
10898
A = randn(T, n, m)
109-
= randn(T, n, m)
11099
Y = f(A)
111100
Ȳ_mat = randn(T, m, n)
112101
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))
113102

114-
rrule_test(f, Ȳ_mat, (A, Ā))
103+
test_rrule(f, A; output_tangent=Ȳ_mat)
115104

116105
_, pb = rrule(f, A)
117106
@test pb(Ȳ_mat) == pb(Ȳ_composite)
118107
end
119108

120109
@testset "$f(::Vector{$T})" begin
121110
a = randn(T, n)
122-
= randn(T, n)
123111
y = f(a)
124112
ȳ_mat = randn(T, 1, n)
125113
ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))
126114

127-
rrule_test(f, ȳ_mat, (a, ā))
115+
test_rrule(f, a; output_tangent=ȳ_mat)
128116

129117
_, pb = rrule(f, a)
130118
@test pb(ȳ_mat) == pb(ȳ_composite)
@@ -136,7 +124,7 @@
136124
y = f(a)
137125
= randn(T, n)
138126

139-
rrule_test(f, ȳ, (a, ā))
127+
test_rrule(f, a ā; output_tangent=)
140128
end
141129

142130
@testset "$f(::Transpose{$T, Vector{$T})" begin
@@ -145,18 +133,18 @@
145133
y = f(a)
146134
= randn(T, n)
147135

148-
rrule_test(f, ȳ, (a, ā))
136+
test_rrule(f, a ā; output_tangent=)
149137
end
150138
end
151139
@testset "$T" for T in (UpperTriangular, LowerTriangular)
152140
n = 5
153-
rrule_test(T, T(randn(n, n)), (randn(n, n), randn(n, n)))
141+
test_rrule(T, randn(n, n); output_tangent=T(randn(n, n)))
154142
end
155143
@testset "$Op" for Op in (triu, tril)
156144
n = 7
157-
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)))
145+
test_rrule(Op, randn(n, n))
158146
@testset "k=$k" for k in -2:2
159-
rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing))
147+
test_rrule(Op, randn(n, n), k nothing)
160148
end
161149
end
162150

@@ -167,7 +155,7 @@
167155
# rand (not randn) so det will be postive, so logdet will be defined
168156
X = S(3*rand(T, (n, n)) .+ 1)
169157
X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
170-
rrule_test(op, rand(T), (X, X̄_acc))
158+
test_rrule(op, X X̄_acc)
171159
end
172160
@testset "return type" begin
173161
X = S(3*rand(6, 6) .+ 1)

0 commit comments

Comments
 (0)