|
81 | 81 | @testset "svd" begin
|
82 | 82 | for n in [4, 6, 10], m in [3, 5, 10]
|
83 | 83 | X = randn(n, m)
|
84 |
| - F, dX_pullback = rrule(svd, X) |
85 |
| - for p in [:U, :S, :V, :Vt] |
86 |
| - Y, dF_pullback = rrule(getproperty, F, p) |
87 |
| - Ȳ = randn(size(Y)...) |
88 |
| - |
89 |
| - dself1, dF, dp = dF_pullback(Ȳ) |
90 |
| - @test dself1 === NO_FIELDS |
91 |
| - @test dp === DoesNotExist() |
92 |
| - |
93 |
| - dself2, dX = dX_pullback(dF) |
94 |
| - @test dself2 === NO_FIELDS |
95 |
| - X̄_ad = unthunk(dX) |
96 |
| - X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)) |
97 |
| - @test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6)) |
| 84 | + @testset "($n by $m) svd" begin |
| 85 | + test_rrule(svd, X) |
| 86 | + end |
| 87 | + @testset "($n by $m) getproperty" begin |
| 88 | + F = svd(X) |
| 89 | + test_rrule(getproperty, F, :U; check_inferred=false) |
| 90 | + test_rrule(getproperty, F, :S; check_inferred=false) |
| 91 | + test_rrule(getproperty, F, :Vt; check_inferred=false) |
| 92 | + test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=adjoint(rand(n, m))) |
98 | 93 | end
|
99 | 94 | end
|
100 | 95 |
|
|
122 | 117 | end
|
123 | 118 | end
|
124 | 119 |
|
125 |
| - @testset "+" begin |
126 |
| - X = [1.0 2.0; 3.0 4.0; 5.0 6.0] |
127 |
| - F, dX_pullback = rrule(svd, X) |
128 |
| - X̄ = Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) |
129 |
| - for p in [:U, :S, :V, :Vt] |
130 |
| - Y, dF_pullback = rrule(getproperty, F, p) |
131 |
| - Ȳ = ones(size(Y)...) |
132 |
| - dself, dF, dp = dF_pullback(Ȳ) |
133 |
| - @test dself === NO_FIELDS |
134 |
| - @test dp === DoesNotExist() |
135 |
| - X̄ += dF |
136 |
| - end |
137 |
| - @test X̄.U ≈ ones(3, 2) atol=1e-6 |
138 |
| - @test X̄.S ≈ ones(2) atol=1e-6 |
139 |
| - @test X̄.Vt ≈ 2 * ones(2, 2) atol=1e-6 # * 2 because V and Vt both accumulate to Vt |
140 |
| - end |
141 |
| - |
142 | 120 | @testset "Helper functions" begin
|
143 | 121 | X = randn(10, 10)
|
144 | 122 | Y = randn(10, 10)
|
|
0 commit comments