1+ # Create a struct to hold the fields of the Woodbury.
2+ # We do this because the Woodbury cannot represent it's own tangent.
3+ # I.e. for Y = f(W,...), W̄ = ∂Y / ∂W, is not necessarily a valid Woodbury.
4+ # Consider, e.g. the case of a Positive Diagonal matrix
5+ struct WoodburyLike
6+ A
7+ D
8+ S
9+ end
10+
11+ # Overwrite the generic to_bec and replace with the almost identical Woodbury specific.
12+ # This means that in FiniteDifferences, the WoodburyLike matrix is created instead of the Woodbury.
13+ # Because the construction is forced there, this would bypass the valdidation checks on the constructor.
14+ function FiniteDifferences. to_vec (x:: T ) where {T<: WoodburyPDMat }
15+ val_vecs_and_backs = map (name -> to_vec (getfield (x, name)), fieldnames (T))
16+ vals = first .(val_vecs_and_backs)
17+ backs = last .(val_vecs_and_backs)
18+
19+ v, vals_from_vec = to_vec (vals)
20+ function structtype_from_vec (v:: Vector{<:Real} )
21+ val_vecs = vals_from_vec (v)
22+ values = map ((b, v) -> b (v), backs, val_vecs)
23+ WoodburyLike (values... )
24+ end
25+ return v, structtype_from_vec
26+ end
27+
28+ # Assign some algebra for the WoodburyLike.
29+ WoodburyPDMat (S:: WoodburyLike ) = WoodburyPDMat (S. A, S. D, S. S)
30+ Base.:* (A:: AbstractVecOrMat , B:: WoodburyLike ) = A * WoodburyPDMat (B)
31+ Base.:* (A:: WoodburyLike , B:: AbstractVecOrMat ) = WoodburyPDMat (A) * B
32+ Base.:* (A:: Real , B:: WoodburyLike ) = A * WoodburyPDMat (B)
33+ Base.:* (A:: WoodburyLike , B:: Real ) = WoodburyPDMat (A) * B
34+ LinearAlgebra. dot (A, B:: WoodburyLike ) = dot (A, WoodburyPDMat (B))
35+
36+ @testset " ChainRules" begin
37+
38+ W = WoodburyPDMat (rand (4 ,2 ), Diagonal (rand (2 ,)), Diagonal (rand (4 ,)))
39+ R = 2.0
40+ D = Diagonal (rand (4 ,))
41+
42+ @testset " *(Matrix-Woodbury)" begin
43+ test_rrule (* , D, W)
44+ test_rrule (* , W, D)
45+ test_rrule (* , rand (4 ,4 ), W)
46+ end
47+
48+ @testset " *(Real-Woodbury" begin
49+ @testset " Matrix Tangent" begin
50+ # ##
51+
52+ primal = R * W
53+
54+ # Matrix Tangent
55+ T = rand_tangent (Matrix (primal))
56+ f_jvp = j′vp (ChainRulesTestUtils. _fdm, x -> Matrix (* (x... )), T, (R, W))[1 ]
57+ R̄ = dot (T, W' )
58+ W̄ = conj (R) * T
59+
60+ R̄ ≈ f_jvp[1 ]
61+ (W. D * W. A' * W̄' + W. D * W. A' * W̄) ≈ f_jvp[2 ]. A' # A transpose.
62+ Diagonal (W. A' * (W̄) * W. A) ≈ f_jvp[2 ]. D # D
63+ Diagonal (W̄) ≈ f_jvp[2 ]. S # S
64+
65+ # Cannot get this to work. Here the T will be
66+ # T = rand_tangent(primal::WoodburyPDMat) which breaks.
67+ # test_rrule(*, 5.0, W)
68+
69+ end
70+ end
71+ end
0 commit comments