Skip to content

Commit e6dd4ef

Browse files
harisorgngithub-actions[bot]torfjelde
authored
Add vectorize method for LKJCholesky (#485)
* using `LinearAlgebra.Cholesky` * add `vectorize` for `LKJCholesky` * add `vectorize` test * add forgotten `end` * Update test/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix typo * add `reconstruct` methods for LKJ/LKJCholesky inv bijectors * bump patch * bump Bijectors compat * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add Bijectors v0.13 compat * add `inittrans` method for `CholeskyVariate` * add `LKJ`/`LKJCholesky` tests Co-authored-by: torfjelde * include tests * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * make tests more accurate * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update test/lkj.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * test `LKJCholesky` for both `'U'` and `'L'` * remove unnecessary `float` wrap * Update test/lkj.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent 5f74696 commit e6dd4ef

File tree

7 files changed

+81
-3
lines changed

7 files changed

+81
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.0"
3+
version = "0.23.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ using Setfield: Setfield
1515
using ZygoteRules: ZygoteRules
1616
using LogDensityProblems: LogDensityProblems
1717

18+
using LinearAlgebra: Cholesky
19+
1820
using DocStringExtensions
1921

2022
using Random: Random

src/abstract_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ end
570570
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
571571
"""
572572
reconstruct_and_link(dist, val)
573-
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
573+
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
574574
575575
Return linked `val` but reconstruct before linking, if necessary.
576576

src/utils.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ vectorize(d, r) = vec(r)
213213
vectorize(d::UnivariateDistribution, r::Real) = [r]
214214
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
215215
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
216+
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))
216217

217218
# NOTE:
218219
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
@@ -235,6 +236,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
235236
reconstruct(::UnivariateDistribution, val::Real) = val
236237
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
237238
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
239+
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
240+
function reconstruct(
241+
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
242+
)
243+
return copy(val)
244+
end
245+
238246
# TODO: Implement no-op `reconstruct` for general array variates.
239247

240248
reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
@@ -294,7 +302,12 @@ function inittrans(rng, dist::MatrixDistribution)
294302
sz = Bijectors.output_size(b, size(dist))
295303
return Bijectors.invlink(dist, randrealuni(rng, sz...))
296304
end
297-
305+
function inittrans(rng, dist::Distribution{CholeskyVariate})
306+
# Get the size of the unconstrained vector
307+
b = link_transform(dist)
308+
sz = Bijectors.output_size(b, size(dist))
309+
return Bijectors.invlink(dist, randrealuni(rng, sz...))
310+
end
298311
################################
299312
# Multi-sample initialisations #
300313
################################

test/lkj.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using Bijectors: pd_from_upper, pd_from_lower
2+
3+
function pd_from_triangular(X::AbstractMatrix, uplo::Char)
4+
return uplo == 'U' ? pd_from_upper(X) : pd_from_lower(X)
5+
end
6+
7+
@model lkj_prior_demo() = x ~ LKJ(2, 1)
8+
@model lkj_chol_prior_demo(uplo) = x ~ LKJCholesky(2, 1, uplo)
9+
10+
# Same for both distributions
11+
target_mean = vec(Matrix{Float64}(I, 2, 2))
12+
13+
_lkj_atol = 0.05
14+
15+
@testset "Sample from x ~ LKJ(2, 1)" begin
16+
model = lkj_prior_demo()
17+
# `SampleFromPrior` will sample in constrained space.
18+
@testset "SampleFromPrior" begin
19+
samples = sample(model, SampleFromPrior(), 1_000)
20+
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
21+
_lkj_atol
22+
end
23+
24+
# `SampleFromUniform` will sample in unconstrained space.
25+
@testset "SampleFromUniform" begin
26+
samples = sample(model, SampleFromUniform(), 1_000)
27+
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
28+
_lkj_atol
29+
end
30+
end
31+
32+
@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L']
33+
model = lkj_chol_prior_demo(uplo)
34+
# `SampleFromPrior` will sample in unconstrained space.
35+
@testset "SampleFromPrior" begin
36+
samples = sample(model, SampleFromPrior(), 1_000)
37+
# Build correlation matrix from factor
38+
corr_matrices = map(samples) do s
39+
M = reshape(s.metadata.vals, (2, 2))
40+
pd_from_triangular(M, uplo)
41+
end
42+
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
43+
end
44+
45+
# `SampleFromUniform` will sample in unconstrained space.
46+
@testset "SampleFromUniform" begin
47+
samples = sample(model, SampleFromUniform(), 1_000)
48+
# Build correlation matrix from factor
49+
corr_matrices = map(samples) do s
50+
M = reshape(s.metadata.vals, (2, 2))
51+
pd_from_triangular(M, uplo)
52+
end
53+
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
54+
end
55+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ include("test_util.jl")
5050
include("serialization.jl")
5151

5252
include("loglikelihoods.jl")
53+
54+
include("lkj.jl")
5355
end
5456

5557
@testset "compat" begin

test/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,10 @@
4242
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
4343
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
4444
end
45+
46+
@testset "vectorize" begin
47+
dist = LKJCholesky(2, 1)
48+
x = rand(dist)
49+
@test vectorize(dist, x) == vec(x.UL)
50+
end
4551
end

0 commit comments

Comments
 (0)