Skip to content

Commit 5d880fe

Browse files
author
Miha Zgubic
committed
replace tests with test_rrule calls
1 parent 5ba3dbd commit 5d880fe

File tree

2 files changed

+11
-31
lines changed

2 files changed

+11
-31
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ version = "0.7.55"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
810
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
911
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1012
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,15 @@ end
8181
@testset "svd" begin
8282
for n in [4, 6, 10], m in [3, 5, 10]
8383
X = randn(n, m)
84-
F, dX_pullback = rrule(svd, X)
85-
for p in [:U, :S, :V, :Vt]
86-
Y, dF_pullback = rrule(getproperty, F, p)
87-
= randn(size(Y)...)
88-
89-
dself1, dF, dp = dF_pullback(Ȳ)
90-
@test dself1 === NO_FIELDS
91-
@test dp === DoesNotExist()
92-
93-
dself2, dX = dX_pullback(dF)
94-
@test dself2 === NO_FIELDS
95-
X̄_ad = unthunk(dX)
96-
X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X))
97-
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))
84+
@testset "($n by $m) svd" begin
85+
test_rrule(svd, X)
86+
end
87+
@testset "($n by $m) getproperty" begin
88+
F = svd(X)
89+
test_rrule(getproperty, F, :U; check_inferred=false)
90+
test_rrule(getproperty, F, :S; check_inferred=false)
91+
test_rrule(getproperty, F, :Vt; check_inferred=false)
92+
test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=adjoint(rand(n, m)))
9893
end
9994
end
10095

@@ -122,23 +117,6 @@ end
122117
end
123118
end
124119

125-
@testset "+" begin
126-
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
127-
F, dX_pullback = rrule(svd, X)
128-
= Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
129-
for p in [:U, :S, :V, :Vt]
130-
Y, dF_pullback = rrule(getproperty, F, p)
131-
= ones(size(Y)...)
132-
dself, dF, dp = dF_pullback(Ȳ)
133-
@test dself === NO_FIELDS
134-
@test dp === DoesNotExist()
135-
+= dF
136-
end
137-
@test.U ones(3, 2) atol=1e-6
138-
@test.S ones(2) atol=1e-6
139-
@test.Vt 2 * ones(2, 2) atol=1e-6 # * 2 because V and Vt both accumulate to Vt
140-
end
141-
142120
@testset "Helper functions" begin
143121
X = randn(10, 10)
144122
Y = randn(10, 10)

0 commit comments

Comments
 (0)