Skip to content

Commit 10bc061

Browse files
author
Alex Robson
committed
Tidy up tests. White space clear
1 parent c586180 commit 10bc061

File tree

3 files changed

+19
-20
lines changed

3 files changed

+19
-20
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
name = "PDMatsExtras"
22
uuid = "2c7acb1b-7338-470f-b38f-951d2bcb9193"
33
authors = ["Invenia Technical Computing"]
4-
version = "2.5.2"
4+
version = "2.5.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
9-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
108
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
119
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1210
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -18,10 +16,12 @@ ChainRulesCore = "1"
1816
Distributions = "0.23, 0.24"
1917
FiniteDifferences = "0.11, 0.12"
2018
PDMats = "0.9, 0.10, 0.11"
21-
Zygote = "0.4, 0.5, 0.6"
19+
Zygote = "0.6"
2220
julia = "1"
2321

2422
[extras]
23+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
24+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2525
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2626
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2727
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -30,4 +30,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3030
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3131

3232
[targets]
33-
test = ["Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
33+
test = ["ChainRulesTestUtils", "ChainRules", "Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]

src/chainrules.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ function ChainRulesCore.rrule(
55
) where {T, TA, TD, TS}
66
project_A = ProjectTo(A)
77
project_B = ProjectTo(B)
8-
98
function times_pullback(ȳ::AbstractMatrix)
109
= unthunk(ȳ)
1110
= dot(Ȳ, B)
@@ -38,9 +37,4 @@ function ChainRulesCore.ProjectTo(W::WoodburyPDMat)
3837
return Tangent{typeof(W)}(; A = (W̄), D = (W̄), S = (W̄))
3938
end
4039
return dW
41-
end
42-
43-
44-
45-
46-
40+
end

test/chainrules.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,27 @@ LinearAlgebra.dot(A, B::WoodburyLike) = dot(A, WoodburyPDMat(B))
5050
###
5151

5252
primal = R * W
53-
5453
# Matrix Tangent
5554
T = rand_tangent(Matrix(primal))
55+
res = ChainRulesCore.rrule(*, R, W)
5656
f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> Matrix(*(x...)), T, (R, W))[1]
57-
= dot(T, W')
58-
= conj(R) * T
5957

60-
f_jvp[1]
61-
(W.D * W.A' *' + W.D * W.A' * W̄) f_jvp[2].A' # A transpose.
62-
Diagonal(W.A' * (W̄) * W.A) f_jvp[2].D # D
63-
Diagonal(W̄) f_jvp[2].S # S
58+
# Expected
59+
= ProjectTo(R)(dot(T, W'))
60+
= ProjectTo(W)(conj(R) * T)
61+
62+
R̄_rrule = unthunk(res[2](T)[2])
63+
W̄_rrule = unthunk(res[2](T)[3])
64+
65+
@test res[1] == primal
66+
@test R̄_rrule f_jvp[1]
67+
@test W̄_rrule.A f_jvp[2].A
68+
@test W̄_rrule.D f_jvp[2].D
69+
@test W̄_rrule.S f_jvp[2].S
6470

6571
# Cannot get this to work. Here the T will be
6672
# T = rand_tangent(primal::WoodburyPDMat) which breaks.
6773
# test_rrule(*, 5.0, W)
68-
6974
end
7075
end
7176
end

0 commit comments

Comments
 (0)