|
7 | 7 |
|
8 | 8 | @testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
|
9 | 9 | LHS = T(randn(T == Diagonal ? 10 : (10, 10)))
|
10 |
| - test_rrule(\, LHS, y = randn(10)) |
11 |
| - test_rrule(\, LHS, y = randn(10, 10)) |
| 10 | + test_rrule(\, LHS, randn(10)) |
| 11 | + test_rrule(\, LHS, randn(10, 10)) |
12 | 12 | end
|
13 | 13 | end
|
14 | 14 |
|
15 | 15 | @testset "Diagonal" begin
|
16 | 16 | N = 3
|
17 |
| - test_rrule(Diagonal, randn(N)) |
18 |
| - D = Diagonal(randn( N)) |
| 17 | + test_rrule(Diagonal, randn(N); output_tangent=randn(N, N)) |
| 18 | + D = Diagonal(randn(N)) |
19 | 19 | test_rrule(Diagonal, randn(N); output_tangent=D)
|
20 | 20 | # Concrete type instead of UnionAll
|
21 | 21 | test_rrule(typeof(D), randn(N); output_tangent=D)
|
|
29 | 29 | end
|
30 | 30 | @testset "dot(x, ::Diagonal, y)" begin
|
31 | 31 | N = 4
|
32 |
| - x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N) |
33 |
| - x̄, d̄, ȳ = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N) |
34 |
| - D = Diagonal(d) |
35 |
| - D̄ = Diagonal(d̄) |
36 |
| - rrule_test(dot, rand(ComplexF64), (x, x̄), (D, D̄), (y, ȳ)) |
| 32 | + test_rrule(dot, randn(ComplexF64, N), Diagonal(randn(ComplexF64, N)), randn(ComplexF64, N)) |
37 | 33 | end
|
38 | 34 | @testset "::Diagonal * ::AbstractVector" begin
|
39 | 35 | N = 3
|
40 |
| - rrule_test( |
41 |
| - *, |
42 |
| - randn(N), |
43 |
| - (Diagonal(randn(N)), Diagonal(randn(N))), |
44 |
| - (randn(N), randn(N)), |
45 |
| - ) |
| 36 | + test_rrule(*, Diagonal(randn(N)), randn(N)) |
46 | 37 | end
|
47 | 38 | @testset "diag" begin
|
48 | 39 | N = 7
|
49 |
| - rrule_test(diag, randn(N), (randn(N, N), randn(N, N))) |
50 |
| - rrule_test(diag, randn(N), (Diagonal(randn(N)), randn(N, N))) |
51 |
| - rrule_test(diag, randn(N), (randn(N, N), Diagonal(randn(N)))) |
52 |
| - rrule_test(diag, randn(N), (Diagonal(randn(N)), Diagonal(randn(N)))) |
| 40 | + test_rrule(diag, randn(N, N)) |
| 41 | + test_rrule(diag, Diagonal(randn(N))) |
| 42 | + test_rrule(diag, randn(N, N) ⊢ Diagonal(randn(N))) |
| 43 | + test_rrule(diag, Diagonal(randn(N)) ⊢ Diagonal(randn(N))) |
53 | 44 | VERSION ≥ v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
|
54 |
| - M = N - abs(k) |
55 |
| - rrule_test(diag, randn(M), (randn(N, N), randn(N, N)), (k, nothing)) |
| 45 | + test_rrule(diag, randn(N, N), k ⊢ nothing) |
56 | 46 | end
|
57 | 47 | end
|
58 |
| - @testset "diagm" begin |
| 48 | + @testset "diagm" begin # TODO review testset |
59 | 49 | @testset "without size" begin
|
60 | 50 | M, N = 7, 9
|
61 | 51 | s = (8, 8)
|
|
106 | 96 | m = 3
|
107 | 97 | @testset "$f(::Matrix{$T})" begin
|
108 | 98 | A = randn(T, n, m)
|
109 |
| - Ā = randn(T, n, m) |
110 | 99 | Y = f(A)
|
111 | 100 | Ȳ_mat = randn(T, m, n)
|
112 | 101 | Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))
|
113 | 102 |
|
114 |
| - rrule_test(f, Ȳ_mat, (A, Ā)) |
| 103 | + test_rrule(f, A; output_tangent=Ȳ_mat) |
115 | 104 |
|
116 | 105 | _, pb = rrule(f, A)
|
117 | 106 | @test pb(Ȳ_mat) == pb(Ȳ_composite)
|
118 | 107 | end
|
119 | 108 |
|
120 | 109 | @testset "$f(::Vector{$T})" begin
|
121 | 110 | a = randn(T, n)
|
122 |
| - ā = randn(T, n) |
123 | 111 | y = f(a)
|
124 | 112 | ȳ_mat = randn(T, 1, n)
|
125 | 113 | ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))
|
126 | 114 |
|
127 |
| - rrule_test(f, ȳ_mat, (a, ā)) |
| 115 | + test_rrule(f, a; output_tangent=ȳ_mat) |
128 | 116 |
|
129 | 117 | _, pb = rrule(f, a)
|
130 | 118 | @test pb(ȳ_mat) == pb(ȳ_composite)
|
|
136 | 124 | y = f(a)
|
137 | 125 | ȳ = randn(T, n)
|
138 | 126 |
|
139 |
| - rrule_test(f, ȳ, (a, ā)) |
| 127 | + test_rrule(f, a ⊢ ā; output_tangent=ȳ) |
140 | 128 | end
|
141 | 129 |
|
142 | 130 | @testset "$f(::Transpose{$T, Vector{$T})" begin
|
|
145 | 133 | y = f(a)
|
146 | 134 | ȳ = randn(T, n)
|
147 | 135 |
|
148 |
| - rrule_test(f, ȳ, (a, ā)) |
| 136 | + test_rrule(f, a ⊢ ā; output_tangent=ȳ) |
149 | 137 | end
|
150 | 138 | end
|
151 | 139 | @testset "$T" for T in (UpperTriangular, LowerTriangular)
|
152 | 140 | n = 5
|
153 |
| - rrule_test(T, T(randn(n, n)), (randn(n, n), randn(n, n))) |
| 141 | + test_rrule(T, randn(n, n); output_tangent=T(randn(n, n))) |
154 | 142 | end
|
155 | 143 | @testset "$Op" for Op in (triu, tril)
|
156 | 144 | n = 7
|
157 |
| - rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n))) |
| 145 | + test_rrule(Op, randn(n, n)) |
158 | 146 | @testset "k=$k" for k in -2:2
|
159 |
| - rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing)) |
| 147 | + test_rrule(Op, randn(n, n), k ⊢ nothing) |
160 | 148 | end
|
161 | 149 | end
|
162 | 150 |
|
|
167 | 155 | # rand (not randn) so det will be postive, so logdet will be defined
|
168 | 156 | X = S(3*rand(T, (n, n)) .+ 1)
|
169 | 157 | X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
|
170 |
| - rrule_test(op, rand(T), (X, X̄_acc)) |
| 158 | + test_rrule(op, X ⊢ X̄_acc) |
171 | 159 | end
|
172 | 160 | @testset "return type" begin
|
173 | 161 | X = S(3*rand(6, 6) .+ 1)
|
|
0 commit comments