Skip to content

Commit 2c68930

Browse files
AlexRobsonmzgubic
andauthored
Extend Woodbury multiplication to Diagonals (#14)
* Minor changes * Bump version * Add check that Diagonal matricies are posdef * Add ArgumentError test * Update Project.toml Co-authored-by: Miha Zgubic <[email protected]> Co-authored-by: Miha Zgubic <[email protected]>
1 parent 160fc34 commit 2c68930

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PDMatsExtras"
22
uuid = "2c7acb1b-7338-470f-b38f-951d2bcb9193"
33
authors = ["Invenia Technical Computing"]
4-
version = "2.2.0"
4+
version = "2.3.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/woodbury_pd_mat.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,14 @@ end
8888
# NOTE: the parameterisation to scale up the Woodbury matrix is not unique. Here we
8989
# implement one way to scale it.
9090
*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
91+
*(c::Real, a::WoodburyPDMat) = a * c
92+
function *(a::WoodburyPDMat, c::Diagonal{T}) where {T<:Real}
93+
isposdef(c) || throw(ArgumentError("c must be positive definite"))
94+
WoodburyPDMat(sqrt(c) * a.A, a.D, a.S * c)
95+
end
96+
*(c::Diagonal{T}, a::WoodburyPDMat) where {T<:Real} = a * c
97+
function *(c1::Diagonal{T}, a::WoodburyPDMat, c2::Diagonal{T}) where {T<:Real}
98+
isposdef(c1) || throw(ArgumentError("c1 must be positive definite"))
99+
isposdef(c2) || throw(ArgumentError("c2 must be positive definite"))
100+
WoodburyPDMat(sqrt(c1) * sqrt(c2) * a.A, a.D, c1 * a.S * c2)
101+
end

test/woodbury_pd_mat.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@
4343
c = 2.0
4444
@test c * W == W * c
4545
@test c * W_dense c * W atol=1e-6
46+
@test (c * W) isa WoodburyPDMat
47+
48+
c = Diagonal(2.0 * ones(4,))
49+
@test c * W == W * c
50+
@test c * W_dense c * W atol=1e-6
51+
@test (c * W) isa WoodburyPDMat
52+
53+
c1 = Diagonal(2.0 * ones(4,))
54+
c2 = Diagonal(3.0 * ones(4,))
55+
c_neg = Diagonal([1,2,-2,3])
56+
57+
@test c2 * W * c1 == c1 * W * c2
58+
@test c1 * W * c2 c1 * W_dense * c2
59+
@test (c1 * W * c2) isa WoodburyPDMat
60+
61+
@test_throws(ArgumentError, c_neg * W)
62+
@test_throws(ArgumentError, c_neg * W * c2)
63+
@test_throws(ArgumentError, c1 * W * c_neg)
64+
4665
end
4766

4867
@testset "MvNormal logpdf" begin

0 commit comments

Comments
 (0)