Skip to content

Commit 7f07352

Browse files
committed
Add PositiveDefinite implementation and tests
1 parent e0e41bf commit 7f07352

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

src/parameters_matrix.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ end
4141
"""
4242
positive_semidefinite(X::AbstractMatrix{<:Real})
4343
44-
Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The argument `X` needs to
45-
be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix).
44+
Produce a parameter whose `value` is constrained to be a positive-semidefinite matrix. The
45+
argument `X` needs to be a positive-definite matrix
46+
(see https://en.wikipedia.org/wiki/Definite_matrix).
4647
4748
The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector.
4849
@@ -57,6 +58,23 @@ function positive_semidefinite(X::AbstractMatrix{<:Real})
5758
return PositiveSemiDefinite(tril_to_vec(cholesky(X).L))
5859
end
5960

61+
"""
62+
positive_definite(X::AbstractMatrix{<:Real}, ε = eps(T))
63+
64+
Produce a parameter whose `value` is constrained to be a strictly positive-semidefinite
65+
matrix. The argument `X` minus `ε` times the identity needs to be a positive-definite matrix
66+
(see https://en.wikipedia.org/wiki/Definite_matrix). The optional second argument `ε` must
67+
be a positive real number.
68+
69+
The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector.
70+
"""
71+
function positive_definite(X::AbstractMatrix{T}, ε = eps(T)) where T <: Real
72+
ε > 0 || throw(ArgumentError("ε is not positive. Use `positive_semidefinite` instead."))
73+
_X = X - ε * I
74+
isposdef(_X) || throw(ArgumentError("X-ε*I is not positive-definite for ε="))
75+
return PositiveDefinite(tril_to_vec(cholesky(_X).L), ε)
76+
end
77+
6078
struct PositiveSemiDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter
6179
L::TL
6280
end
@@ -73,6 +91,21 @@ function flatten(::Type{T}, X::PositiveSemiDefinite) where {T<:Real}
7391
return v, unflatten_PositiveSemiDefinite
7492
end
7593

94+
struct PositiveDefinite{TL<:AbstractVector{<:Real}, Tε<:Real} <: AbstractParameter
95+
L::TL
96+
ε::Tε
97+
end
98+
99+
Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L
100+
101+
value(X::PositiveDefinite) = A_At(vec_to_tril(X.L)) + X.ε * I
102+
103+
function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}
104+
v, unflatten_v = flatten(T, X.L)
105+
unflatten_PositiveDefinite(v_new::Vector{T}) = PositiveDefinite(unflatten_v(v_new), X.ε)
106+
return v, unflatten_PositiveDefinite
107+
end
108+
76109
# Convert a vector to lower-triangular matrix
77110
function vec_to_tril(v::AbstractVector{T}) where {T}
78111
n_vec = length(v)

test/parameters_matrix.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,24 @@ using ParameterHandling: vec_to_tril, tril_to_vec
5454
@test vec_to_tril(Δl) == tril(ΔL)
5555
ChainRulesTestUtils.test_rrule(vec_to_tril, x)
5656
end
57+
58+
@testset "positive_definite" begin
59+
X_mat = ParameterHandling.A_At(rand(3, 3)) # Create a positive definite object
60+
X = positive_definite(X_mat)
61+
@test isposdef(value(X))
62+
X.L .= 0 # zero the unconstrained value
63+
@test isposdef(value(X))
64+
@test_throws ArgumentError positive_definite(zeros(3, 3))
65+
@test_throws ArgumentError positive_definite(X_mat, 0.)
66+
test_parameter_interface(X)
67+
68+
x, re = flatten(X)
69+
Δl = first(
70+
Zygote.gradient(x) do x
71+
X = re(x)
72+
return logdet(value(X))
73+
end,
74+
)
75+
ChainRulesTestUtils.test_rrule(vec_to_tril, x)
76+
end
5777
end

0 commit comments

Comments
 (0)