|
1 | 1 | @testset "dense" begin
|
2 | 2 | @testset "dot" begin
|
3 | 3 | @testset "Vector{$T}" for T in (Float64, ComplexF64)
|
4 |
| - M = 3 |
5 |
| - x, y = randn(T, M), randn(T, M) |
6 |
| - ẋ, ẏ = randn(T, M), randn(T, M) |
7 |
| - x̄, ȳ = randn(T, M), randn(T, M) |
8 |
| - frule_test(dot, (x, ẋ), (y, ẏ)) |
9 |
| - rrule_test(dot, randn(T), (x, x̄), (y, ȳ)) |
| 4 | + test_frule(dot, randn(T, 3), randn(T, 3)) |
| 5 | + test_rrule(dot, randn(T, 3), randn(T, 3)) |
10 | 6 | end
|
11 | 7 | @testset "Matrix{$T}" for T in (Float64, ComplexF64)
|
12 |
| - M, N = 3, 4 |
13 |
| - x, y = randn(T, M, N), randn(T, M, N) |
14 |
| - ẋ, ẏ = randn(T, M, N), randn(T, M, N) |
15 |
| - x̄, ȳ = randn(T, M, N), randn(T, M, N) |
16 |
| - frule_test(dot, (x, ẋ), (y, ẏ)) |
17 |
| - rrule_test(dot, randn(T), (x, x̄), (y, ȳ)) |
| 8 | + test_frule(dot, randn(T, 3, 4), randn(T, 3, 4)) |
| 9 | + test_rrule(dot, randn(T, 3, 4), randn(T, 3, 4)) |
18 | 10 | end
|
19 | 11 | @testset "Array{$T, 3}" for T in (Float64, ComplexF64)
|
20 |
| - M, N, P = 3, 4, 5 |
21 |
| - x, y = randn(T, M, N, P), randn(T, M, N, P) |
22 |
| - ẋ, ẏ = randn(T, M, N, P), randn(T, M, N, P) |
23 |
| - x̄, ȳ = randn(T, M, N, P), randn(T, M, N, P) |
24 |
| - frule_test(dot, (x, ẋ), (y, ẏ)) |
25 |
| - rrule_test(dot, randn(T), (x, x̄), (y, ȳ)) |
| 12 | + test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) |
| 13 | + test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) |
26 | 14 | end
|
27 | 15 | @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
|
28 |
| - M, N = 3, 4 |
29 |
| - x, A, y = randn(T, M), randn(T, M, N), randn(T, N) |
30 |
| - ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N) |
31 |
| - x̄, Abar, ȳ = randn(T, M), randn(T, M, N), randn(T, N) |
32 |
| - frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) |
33 |
| - rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) |
| 16 | + test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) |
| 17 | + test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) |
34 | 18 | end
|
35 | 19 | permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
|
36 | 20 | @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
|
37 |
| - M, N = 3, 4 |
38 |
| - x, A, y = rand(T, M), F(rand(T, N, M)), rand(T, N) |
39 |
| - ẋ, Adot, ẏ = rand(T, M), F(rand(T, N, M)), rand(T, N) |
40 |
| - x̄, Abar, ȳ = rand(T, M), F(rand(T, N, M)), rand(T, N) |
41 |
| - frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ); rtol=1f-3) |
42 |
| - rrule_test(dot, float(rand(T)), (x, x̄), (A, Abar), (y, ȳ); rtol=1f-3) |
| 21 | + A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3)) |
| 22 | + test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) |
| 23 | + test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) |
43 | 24 | end
|
44 | 25 | end
|
45 |
| - @testset "cross" begin |
46 |
| - @testset "frule" begin |
47 |
| - @testset "$T" for T in (Float64, ComplexF64) |
48 |
| - n = 3 |
49 |
| - x, y = randn(T, n), randn(T, n) |
50 |
| - ẋ, ẏ = randn(T, n), randn(T, n) |
51 |
| - frule_test(cross, (x, ẋ), (y, ẏ)) |
52 |
| - end |
53 |
| - end |
54 |
| - @testset "rrule" begin |
55 |
| - n = 3 |
56 |
| - x, y = randn(n), randn(n) |
57 |
| - x̄, ȳ = randn(n), randn(n) |
58 |
| - ΔΩ = randn(n) |
59 |
| - rrule_test(cross, ΔΩ, (x, x̄), (y, ȳ)) |
60 |
| - end |
| 26 | + |
| 27 | + @testset "cross" |
| 28 | + test_frule(cross, randn(3), randn(3)) |
| 29 | + test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) |
| 30 | + test_rrule(cross, randn(3), randn(3)) |
| 31 | + # No complex support for rrule(cross,... |
61 | 32 | end
|
62 | 33 | @testset "pinv" begin
|
63 | 34 | @testset "$T" for T in (Float64, ComplexF64)
|
|
66 | 37 | @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T)
|
67 | 38 | end
|
68 | 39 | @testset "Vector{$T}" for T in (Float64, ComplexF64)
|
69 |
| - n = 3 |
70 |
| - x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n) |
71 |
| - tol, ṫol, t̄ol = 0.0, randn(), randn() |
72 |
| - Δy = copyto!(similar(pinv(x)), randn(T, n)) |
73 |
| - frule_test(pinv, (x, ẋ), (tol, ṫol)) |
| 40 | + test_frule(pinv, randn(T, 3), 0.0) |
| 41 | + test_frule(pinv, randn(T, 3), 0.0) |
| 42 | + |
| 43 | + # Checking types. TODO do we still need this? |
| 44 | + x = randn(T, 3) |
| 45 | + ẋ = randn(T, 3) |
| 46 | + Δy = copyto!(similar(pinv(x)), randn(T, 3)) |
74 | 47 | @test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
|
75 |
| - rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol)) |
76 | 48 | @test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
|
77 | 49 | end
|
| 50 | + #TODO Everything after this point |
78 | 51 | @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
|
79 |
| - n = 3 |
80 |
| - x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n)) |
| 52 | + x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3)) |
81 | 53 | y = pinv(x)
|
82 |
| - Δy = copyto!(similar(y), randn(T, n)) |
83 |
| - frule_test(pinv, (x, ẋ)) |
| 54 | + Δy = copyto!(similar(y), randn(T, 3)) |
| 55 | + test_frule(pinv, (x, ẋ)) |
84 | 56 | y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
|
85 | 57 | @test y_fwd isa typeof(y)
|
86 | 58 | @test ∂y_fwd isa typeof(y)
|
87 |
| - rrule_test(pinv, Δy, (x, x̄)) |
| 59 | + test_rrule(pinv, Δy, (x, x̄)) |
88 | 60 | y_rev, back = rrule(pinv, x)
|
89 | 61 | @test y_rev isa typeof(y)
|
90 | 62 | @test back(Δy)[2] isa typeof(x)
|
91 | 63 | end
|
92 |
| - @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), |
| 64 | + @testset "Matrix{$T} with size ($m,$3)" for T in (Float64, ComplexF64), |
93 | 65 | m in 1:3,
|
94 |
| - n in 1:3 |
| 66 | + 3 in 1:3 |
95 | 67 |
|
96 |
| - X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n) |
| 68 | + X, Ẋ, X̄ = randn(T, m, 3), randn(T, m, 3), randn(T, m, 3) |
97 | 69 | ΔY = randn(T, size(pinv(X))...)
|
98 |
| - frule_test(pinv, (X, Ẋ)) |
99 |
| - rrule_test(pinv, ΔY, (X, X̄)) |
| 70 | + test_frule(pinv, (X, Ẋ)) |
| 71 | + test_rrule(pinv, ΔY, (X, X̄)) |
100 | 72 | end
|
101 | 73 | end
|
102 | 74 | @testset "$f" for f in (det, logdet)
|
|
110 | 82 | else
|
111 | 83 | kwargs = NamedTuple()
|
112 | 84 | end
|
113 |
| - N = 3 |
114 |
| - B = generate_well_conditioned_matrix(T, N) |
115 |
| - frule_test(f, (B, randn(T, N, N)); kwargs...) |
116 |
| - rrule_test(f, randn(T), (B, randn(T, N, N)); kwargs...) |
| 85 | + B = generate_well_conditioned_matrix(T, 4) |
| 86 | + test_frule(f, (B, randn(T, 4, 4)); kwargs...) |
| 87 | + test_rrule(f, randn(T), (B, randn(T, 4, 4)); kwargs...) |
117 | 88 | end
|
118 | 89 | end
|
119 | 90 | @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
|
120 |
| - N = 3 |
121 |
| - B = randn(T, N, N) |
122 |
| - frule_test(logabsdet, (B, randn(T, N, N))) |
123 |
| - rrule_test(logabsdet, (randn(), randn(T)), (B, randn(T, N, N))) |
| 91 | + B = randn(T, 4, 4) |
| 92 | + test_frule(logabsdet, (B, randn(T, 4, 4))) |
| 93 | + test_rrule(logabsdet, (randn(), randn(T)), (B, randn(T, 4, 4))) |
124 | 94 | # test for opposite sign of determinant
|
125 |
| - frule_test(logabsdet, (-B, randn(T, N, N))) |
126 |
| - rrule_test(logabsdet, (randn(), randn(T)), (-B, randn(T, N, N))) |
| 95 | + test_frule(logabsdet, (-B, randn(T, 4, 4))) |
| 96 | + test_rrule(logabsdet, (randn(), randn(T)), (-B, randn(T, 4, 4))) |
127 | 97 | end
|
128 | 98 | @testset "tr" begin
|
129 |
| - N = 4 |
130 |
| - frule_test(tr, (randn(N, N), randn(N, N))) |
131 |
| - rrule_test(tr, randn(), (randn(N, N), randn(N, N))) |
| 99 | + test_frule(tr, (randn(4, 4), randn(4, 4))) |
| 100 | + test_rrule(tr, randn(), (randn(4, 4), randn(4, 4))) |
132 | 101 | end
|
| 102 | + ==# |
133 | 103 | end
|
0 commit comments