|
24 | 24 | end
|
25 | 25 | end
|
26 | 26 |
|
27 |
| - @testset "cross" |
| 27 | + @testset "cross" begin |
28 | 28 | test_frule(cross, randn(3), randn(3))
|
29 | 29 | test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))
|
30 | 30 | test_rrule(cross, randn(3), randn(3))
|
|
47 | 47 | @test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
|
48 | 48 | @test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
|
49 | 49 | end
|
50 |
| - #TODO Everything after this point |
| 50 | + |
51 | 51 | @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
|
| 52 | + test_frule(pinv, F(randn(T, 3)) ⊢ F(randn(T, 3))) |
| 53 | + test_rrule(pinv, F(randn(T, 3))) |
| 54 | + |
| 55 | + # Check types. |
| 56 | + # TODO: Do we need this still? |
52 | 57 | x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3))
|
53 | 58 | y = pinv(x)
|
54 | 59 | Δy = copyto!(similar(y), randn(T, 3))
|
55 |
| - test_frule(pinv, (x, ẋ)) |
| 60 | + |
56 | 61 | y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
|
57 | 62 | @test y_fwd isa typeof(y)
|
58 | 63 | @test ∂y_fwd isa typeof(y)
|
59 |
| - test_rrule(pinv, Δy, (x, x̄)) |
| 64 | + |
60 | 65 | y_rev, back = rrule(pinv, x)
|
61 | 66 | @test y_rev isa typeof(y)
|
62 | 67 | @test back(Δy)[2] isa typeof(x)
|
63 | 68 | end
|
64 |
| - @testset "Matrix{$T} with size ($m,$3)" for T in (Float64, ComplexF64), |
| 69 | + @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), |
65 | 70 | m in 1:3,
|
66 |
| - 3 in 1:3 |
| 71 | + n in 1:3 |
67 | 72 |
|
68 |
| - X, Ẋ, X̄ = randn(T, m, 3), randn(T, m, 3), randn(T, m, 3) |
69 |
| - ΔY = randn(T, size(pinv(X))...) |
70 |
| - test_frule(pinv, (X, Ẋ)) |
71 |
| - test_rrule(pinv, ΔY, (X, X̄)) |
| 73 | + test_frule(pinv, randn(T, m, n)) |
| 74 | + test_rrule(pinv, randn(T, m, n)) |
72 | 75 | end
|
73 | 76 | end
|
74 | 77 | @testset "$f" for f in (det, logdet)
|
|
77 | 80 | test_scalar(f, b)
|
78 | 81 | end
|
79 | 82 | @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64)
|
| 83 | + B = generate_well_conditioned_matrix(T, 4) |
80 | 84 | if f === logdet && float(T) <: Float32
|
81 |
| - kwargs = (atol=1e-5, rtol=1e-5) |
| 85 | + test_frule(f, B; atol=1e-5, rtol=1e-5) |
| 86 | + test_rrule(f, B; atol=1e-5, rtol=1e-5) |
82 | 87 | else
|
83 |
| - kwargs = NamedTuple() |
| 88 | + test_frule(f, B) |
| 89 | + test_rrule(f, B) |
84 | 90 | end
|
85 |
| - B = generate_well_conditioned_matrix(T, 4) |
86 |
| - test_frule(f, (B, randn(T, 4, 4)); kwargs...) |
87 |
| - test_rrule(f, randn(T), (B, randn(T, 4, 4)); kwargs...) |
88 | 91 | end
|
89 | 92 | end
|
90 | 93 | @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
|
91 | 94 | B = randn(T, 4, 4)
|
92 |
| - test_frule(logabsdet, (B, randn(T, 4, 4))) |
93 |
| - test_rrule(logabsdet, (randn(), randn(T)), (B, randn(T, 4, 4))) |
| 95 | + test_frule(logabsdet, B) |
| 96 | + test_rrule(logabsdet, B) |
94 | 97 | # test for opposite sign of determinant
|
95 |
| - test_frule(logabsdet, (-B, randn(T, 4, 4))) |
96 |
| - test_rrule(logabsdet, (randn(), randn(T)), (-B, randn(T, 4, 4))) |
| 98 | + test_frule(logabsdet, -B) |
| 99 | + test_rrule(logabsdet, -B) |
97 | 100 | end
|
98 | 101 | @testset "tr" begin
|
99 |
| - test_frule(tr, (randn(4, 4), randn(4, 4))) |
100 |
| - test_rrule(tr, randn(), (randn(4, 4), randn(4, 4))) |
| 102 | + test_frule(tr, randn(4, 4)) |
| 103 | + test_rrule(tr, randn(4, 4)) |
101 | 104 | end
|
102 |
| - ==# |
103 | 105 | end
|
0 commit comments