|
2 | 2 | @testset "/ and \\ on Square Matrixes" begin
|
3 | 3 | @testset "//, $T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
|
4 | 4 | RHS = T(randn(T == Diagonal ? 10 : (10, 10)))
|
5 |
| - Y = randn(5, 10) |
6 |
| - Ȳ = randn(size(/(Y, RHS))...) |
7 |
| - rrule_test(/, Ȳ, (Y, randn(size(Y))), (RHS, randn(size(RHS)))) |
| 5 | + test_rrule(/, randn(5, 10), RHS) |
8 | 6 | end
|
9 | 7 |
|
10 | 8 | @testset "\\ $T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
|
11 | 9 | LHS = T(randn(T == Diagonal ? 10 : (10, 10)))
|
12 |
| - y = randn(10) |
13 |
| - ȳ = randn(size(\(LHS, y))...) |
14 |
| - rrule_test(\, ȳ, (LHS, randn(size(LHS))), (y, randn(10))) |
15 |
| - Y = randn(10, 10) |
16 |
| - Ȳ = randn(10, 10) |
17 |
| - rrule_test(\, Ȳ, (LHS, randn(size(LHS))), (Y, randn(size(Y)))) |
| 10 | + test_rrule(\, LHS, randn(10)) |
| 11 | + test_rrule(\, LHS, randn(10, 10)) |
18 | 12 | end
|
19 | 13 | end
|
20 | 14 |
|
21 | 15 | @testset "Diagonal" begin
|
22 | 16 | N = 3
|
23 |
| - rrule_test(Diagonal, randn(N, N), (randn(N), randn(N))) |
24 |
| - D = Diagonal(randn( N)) |
25 |
| - rrule_test(Diagonal, D, (randn(N), randn(N))) |
| 17 | + test_rrule(Diagonal, randn(N); output_tangent=randn(N, N)) |
| 18 | + D = Diagonal(randn(N)) |
| 19 | + test_rrule(Diagonal, randn(N); output_tangent=D) |
26 | 20 | # Concrete type instead of UnionAll
|
27 |
| - rrule_test(typeof(D), D, (randn(N), randn(N))) |
| 21 | + test_rrule(typeof(D), randn(N); output_tangent=D) |
28 | 22 |
|
29 | 23 | # TODO: replace this with a `rrule_test` once we have that working
|
30 | 24 | # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24
|
|
35 | 29 | end
|
36 | 30 | @testset "dot(x, ::Diagonal, y)" begin
|
37 | 31 | N = 4
|
38 |
| - x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N) |
39 |
| - x̄, d̄, ȳ = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N) |
40 |
| - D = Diagonal(d) |
41 |
| - D̄ = Diagonal(d̄) |
42 |
| - rrule_test(dot, rand(ComplexF64), (x, x̄), (D, D̄), (y, ȳ)) |
| 32 | + test_rrule(dot, randn(ComplexF64, N), Diagonal(randn(ComplexF64, N)), randn(ComplexF64, N)) |
43 | 33 | end
|
44 | 34 | @testset "::Diagonal * ::AbstractVector" begin
|
45 | 35 | N = 3
|
46 |
| - rrule_test( |
47 |
| - *, |
48 |
| - randn(N), |
49 |
| - (Diagonal(randn(N)), Diagonal(randn(N))), |
50 |
| - (randn(N), randn(N)), |
51 |
| - ) |
| 36 | + test_rrule(*, Diagonal(randn(N)), randn(N)) |
52 | 37 | end
|
53 | 38 | @testset "diag" begin
|
54 | 39 | N = 7
|
55 |
| - rrule_test(diag, randn(N), (randn(N, N), randn(N, N))) |
56 |
| - rrule_test(diag, randn(N), (Diagonal(randn(N)), randn(N, N))) |
57 |
| - rrule_test(diag, randn(N), (randn(N, N), Diagonal(randn(N)))) |
58 |
| - rrule_test(diag, randn(N), (Diagonal(randn(N)), Diagonal(randn(N)))) |
| 40 | + test_rrule(diag, randn(N, N)) |
| 41 | + test_rrule(diag, Diagonal(randn(N))) |
| 42 | + test_rrule(diag, randn(N, N) ⊢ Diagonal(randn(N))) |
| 43 | + test_rrule(diag, Diagonal(randn(N)) ⊢ Diagonal(randn(N))) |
59 | 44 | VERSION ≥ v"1.3" && @testset "k=$k" for k in (-1, 0, 2)
|
60 |
| - M = N - abs(k) |
61 |
| - rrule_test(diag, randn(M), (randn(N, N), randn(N, N)), (k, nothing)) |
| 45 | + test_rrule(diag, randn(N, N), k ⊢ nothing) |
62 | 46 | end
|
63 | 47 | end
|
64 | 48 | @testset "diagm" begin
|
|
112 | 96 | m = 3
|
113 | 97 | @testset "$f(::Matrix{$T})" begin
|
114 | 98 | A = randn(T, n, m)
|
115 |
| - Ā = randn(T, n, m) |
116 | 99 | Y = f(A)
|
117 | 100 | Ȳ_mat = randn(T, m, n)
|
118 | 101 | Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))
|
119 | 102 |
|
120 |
| - rrule_test(f, Ȳ_mat, (A, Ā)) |
| 103 | + test_rrule(f, A; output_tangent=Ȳ_mat) |
121 | 104 |
|
122 | 105 | _, pb = rrule(f, A)
|
123 | 106 | @test pb(Ȳ_mat) == pb(Ȳ_composite)
|
124 | 107 | end
|
125 | 108 |
|
126 | 109 | @testset "$f(::Vector{$T})" begin
|
127 | 110 | a = randn(T, n)
|
128 |
| - ā = randn(T, n) |
129 | 111 | y = f(a)
|
130 | 112 | ȳ_mat = randn(T, 1, n)
|
131 | 113 | ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))
|
132 | 114 |
|
133 |
| - rrule_test(f, ȳ_mat, (a, ā)) |
| 115 | + test_rrule(f, a; output_tangent=ȳ_mat) |
134 | 116 |
|
135 | 117 | _, pb = rrule(f, a)
|
136 | 118 | @test pb(ȳ_mat) == pb(ȳ_composite)
|
|
139 | 121 | @testset "$f(::Adjoint{$T, Vector{$T})" begin
|
140 | 122 | a = randn(T, n)'
|
141 | 123 | ā = randn(T, n)'
|
142 |
| - y = f(a) |
143 |
| - ȳ = randn(T, n) |
144 |
| - |
145 |
| - rrule_test(f, ȳ, (a, ā)) |
| 124 | + test_rrule(f, a ⊢ ā; output_tangent=randn(T, n)) |
146 | 125 | end
|
147 | 126 |
|
148 | 127 | @testset "$f(::Transpose{$T, Vector{$T})" begin
|
149 | 128 | a = transpose(randn(T, n))
|
150 | 129 | ā = transpose(randn(T, n))
|
151 |
| - y = f(a) |
152 |
| - ȳ = randn(T, n) |
153 |
| - |
154 |
| - rrule_test(f, ȳ, (a, ā)) |
| 130 | + test_rrule(f, a ⊢ ā; output_tangent=randn(T, n)) |
155 | 131 | end
|
156 | 132 | end
|
157 | 133 | @testset "$T" for T in (UpperTriangular, LowerTriangular)
|
158 | 134 | n = 5
|
159 |
| - rrule_test(T, T(randn(n, n)), (randn(n, n), randn(n, n))) |
| 135 | + test_rrule(T, randn(n, n); output_tangent=T(randn(n, n))) |
160 | 136 | end
|
161 | 137 | @testset "$Op" for Op in (triu, tril)
|
162 | 138 | n = 7
|
163 |
| - rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n))) |
| 139 | + test_rrule(Op, randn(n, n)) |
164 | 140 | @testset "k=$k" for k in -2:2
|
165 |
| - rrule_test(Op, randn(n, n), (randn(n, n), randn(n, n)), (k, nothing)) |
| 141 | + test_rrule(Op, randn(n, n), k ⊢ nothing) |
166 | 142 | end
|
167 | 143 | end
|
168 | 144 |
|
|
173 | 149 | # rand (not randn) so det will be postive, so logdet will be defined
|
174 | 150 | X = S(3*rand(T, (n, n)) .+ 1)
|
175 | 151 | X̄_acc = Diagonal(rand(T, (n, n))) # sensitivity is always a diagonal for these types
|
176 |
| - rrule_test(op, rand(T), (X, X̄_acc)) |
| 152 | + test_rrule(op, X ⊢ X̄_acc) |
177 | 153 | end
|
178 | 154 | @testset "return type" begin
|
179 | 155 | X = S(3*rand(6, 6) .+ 1)
|
|
0 commit comments