|
1 | 1 | @testset "arraymath" begin
|
2 | 2 | @testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
|
3 |
| - N = 3 |
4 |
| - B = generate_well_conditioned_matrix(T, N) |
5 |
| - frule_test(inv, (B, randn(T, N, N))) |
6 |
| - rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) |
| 3 | + B = generate_well_conditioned_matrix(T, 3) |
| 4 | + test_frule(inv, B) |
| 5 | + test_rrule(inv, B) |
7 | 6 | end
|
8 | 7 |
|
9 | 8 | @testset "*: $T" for T in (Float64, ComplexF64)
|
10 | 9 | ⋆(a) = round.(5*randn(T, a)) # Helper to generate nice random values
|
11 | 10 | ⋆(a, b) = ⋆((a, b)) # matrix
|
12 | 11 | ⋆() = only(⋆(())) # scalar
|
13 | 12 |
|
14 |
| - ⋆₂(a) = (⋆(a), ⋆(a)) # Helper to generate random matrix and its cotangent |
15 |
| - ⋆₂(a, b) = ⋆₂((a, b)) #matrix |
16 |
| - ⋆₂() = ⋆₂(()) # scalar |
17 |
| - |
18 | 13 | @testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5))
|
19 |
| - rrule_test(*, ⋆(dims), ⋆₂(), ⋆₂(dims)) |
20 |
| - rrule_test(*, ⋆(dims), ⋆₂(dims), ⋆₂()) |
| 14 | + test_rrule(*, ⋆(), ⋆(dims)) |
| 15 | + test_rrule(*, ⋆(dims), ⋆()) |
21 | 16 | end
|
22 | 17 |
|
23 | 18 | @testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5)
|
24 | 19 | @testset "Array" begin
|
25 |
| - rrule_test(*, ⋆(n), n ⋆₂ m, ⋆₂(m)) |
| 20 | + test_rrule(*, n ⋆ m, ⋆(m)) |
26 | 21 | end
|
27 | 22 | end
|
28 | 23 |
|
29 | 24 | @testset "AbstractVector-AbstractMatrix n=$n, m=$m" for n in (2, 3), m in (4, 5)
|
30 | 25 | @testset "Array" begin
|
31 |
| - rrule_test(*, n ⋆ m, ⋆₂(n), 1 ⋆₂ m) |
| 26 | + test_rrule(*, ⋆(n), 1 ⋆ m) |
32 | 27 | end
|
33 | 28 | end
|
34 | 29 |
|
35 | 30 | @testset "AbstractMatrix-AbstractMatrix" begin
|
36 | 31 | @testset "Matrix * Matrix n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3)
|
37 | 32 | @testset "Array" begin
|
38 |
| - rrule_test(*, n⋆p, (n⋆₂m), (m⋆₂p)) |
| 33 | + test_rrule(*, (n⋆m), (m⋆p)) |
39 | 34 | end
|
40 | 35 |
|
41 | 36 | @testset "SubArray - $indexname" for (indexname, m_index) in (
|
42 |
| - ("fast", :), ("slow", Ref(m:-1:1)) |
| 37 | + ("fast", :), ("slow", m:-1:1) |
43 | 38 | )
|
44 |
| - rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), view.(m⋆₂p, m_index, :)) |
45 |
| - rrule_test(*, n⋆p, n⋆₂m, view.(m⋆₂p, m_index, :)) |
46 |
| - rrule_test(*, n⋆p, view.(n⋆₂m, :, m_index), m⋆₂p) |
| 39 | + test_rrule(*, view(n⋆m, :, m_index), view(m⋆p, m_index, :)) |
| 40 | + test_rrule(*, n⋆m, view(m⋆p, m_index, :)) |
| 41 | + test_rrule(*, view(n⋆m, :, m_index), m⋆p) |
47 | 42 | end
|
48 | 43 |
|
49 | 44 | @testset "Adjoints and Transposes" begin
|
50 |
| - rrule_test(*, n⋆p, Transpose.(m⋆₂n), Transpose.(p⋆₂m)) |
51 |
| - rrule_test(*, n⋆p, Adjoint.(m⋆₂n), Adjoint.(p⋆₂m)) |
| 45 | + test_rrule(*, Transpose(m⋆n) ⊢ Transpose(m⋆n), Transpose(p⋆m) ⊢ Transpose(p⋆m)) |
| 46 | + test_rrule(*, Adjoint(m⋆n) ⊢ Adjoint(m⋆n), Adjoint(p⋆m) ⊢ Adjoint(p⋆m)) |
52 | 47 |
|
53 |
| - rrule_test(*, n⋆p, Transpose.(m⋆₂n), (m⋆₂p)) |
54 |
| - rrule_test(*, n⋆p, Adjoint.(m⋆₂n), (m⋆₂p)) |
| 48 | + test_rrule(*, Transpose(m⋆n) ⊢ Transpose(m⋆n), (m⋆p)) |
| 49 | + test_rrule(*, Adjoint(m⋆n) ⊢ Adjoint(m⋆n), (m⋆p)) |
55 | 50 |
|
56 |
| - rrule_test(*, n⋆p, (n⋆₂m), Transpose.(p⋆₂m)) |
57 |
| - rrule_test(*, n⋆p, (n⋆₂m), Adjoint.(p⋆₂m)) |
| 51 | + test_rrule(*, (n⋆m), Transpose(p⋆m) ⊢ Transpose(p⋆m)) |
| 52 | + test_rrule(*, (n⋆m), Adjoint(p⋆m) ⊢ Adjoint(p⋆m)) |
58 | 53 | end
|
59 | 54 | end
|
60 | 55 | end
|
61 | 56 |
|
62 | 57 | @testset "Covector * Vector n=$n" for n in (3, 5)
|
63 | 58 | @testset "$f" for f in (adjoint, transpose)
|
64 | 59 | # This should be same as dot product and give a scalar
|
65 |
| - rrule_test(*, ⋆(), f.(⋆₂(n)), ⋆₂(n)) |
| 60 | + test_rrule(*, f(⋆(n)) ⊢ f(⋆(n)), ⋆(n)) |
66 | 61 | end
|
67 | 62 | end
|
68 | 63 | end
|
|
73 | 68 | for n in 3:5, m in 3:5
|
74 | 69 | A = randn(m, n)
|
75 | 70 | B = randn(m, n)
|
76 |
| - Ȳ = randn(size(f(A, B))) |
77 |
| - rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n))) |
| 71 | + test_rrule(f, A, B) |
78 | 72 | end
|
79 | 73 | end
|
80 | 74 | @testset "Vector" begin
|
81 | 75 | x = randn(10)
|
82 | 76 | y = randn(10)
|
83 |
| - ȳ = randn(size(f(x, y))...) |
84 |
| - rrule_test(f, ȳ, (x, randn(10)), (y, randn(10))) |
| 77 | + test_rrule(f, x, y) |
85 | 78 | end
|
86 | 79 | if f == (\)
|
87 | 80 | @testset "Matrix $f Vector" begin
|
88 | 81 | X = randn(10, 4)
|
89 | 82 | y = randn(10)
|
90 |
| - ȳ = randn(size(f(X, y))...) |
91 |
| - rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10))) |
| 83 | + test_rrule(f, X, y) |
92 | 84 | end
|
93 | 85 | @testset "Vector $f Matrix" begin
|
94 | 86 | x = randn(10)
|
95 | 87 | Y = randn(10, 4)
|
96 |
| - ȳ = randn(size(f(x, Y))...) |
97 |
| - rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y)))) |
| 88 | + test_rrule(f, x, Y; output_tangent=Transpose(rand(4))) |
98 | 89 | end
|
99 | 90 | end
|
100 | 91 | end
|
101 | 92 | @testset "/ and \\ Scalar-AbstractArray" begin
|
102 | 93 | A = randn(3, 4, 5)
|
103 |
| - Ā = randn(3, 4, 5) |
104 |
| - Ȳ = randn(3, 4, 5) |
105 |
| - rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3)) |
106 |
| - rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā)) |
| 94 | + test_rrule(/, A, 7.2) |
| 95 | + test_rrule(\, 7.2, A) |
107 | 96 | end
|
108 | 97 |
|
109 | 98 |
|
110 | 99 | @testset "negation" begin
|
111 | 100 | A = randn(4, 4)
|
112 | 101 | Ā = randn(4, 4)
|
113 |
| - Ȳ = randn(4, 4) |
114 |
| - rrule_test(-, Ȳ, (A, Ā)) |
115 |
| - rrule_test(-, Diagonal(Ȳ), (Diagonal(A), Diagonal(Ā))) |
| 102 | + |
| 103 | + test_rrule(-, A) |
| 104 | + test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā)) |
116 | 105 | end
|
117 | 106 | end
|
0 commit comments