From 4b820f56a660b296078c5a8594e8454d05377d8e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:40:25 +0100 Subject: [PATCH 01/16] Transfer MarginalLogDensities extension from Turing See: https://github.com/TuringLang/Turing.jl/pull/2664 Co-authored-by: Sam Urmy --- Project.toml | 2 ++ ext/DynamicPPLMarginalLogDensitiesExt.jl | 36 +++++++++++++++++++ src/DynamicPPL.jl | 8 ++--- test/Project.toml | 1 + test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 27 ++++++++++++++ test/runtests.jl | 1 + 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 ext/DynamicPPLMarginalLogDensitiesExt.jl create mode 100644 test/ext/DynamicPPLMarginalLogDensitiesExt.jl diff --git a/Project.toml b/Project.toml index 5f11cba3f..56a681375 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] @@ -41,6 +42,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] +DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..8153d1522 --- /dev/null +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,36 @@ +module DynamicPPLMarginalLogDensitiesExt + +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using MarginalLogDensities: MarginalLogDensities + +_to_varname(n::Symbol) = VarName{n}() +_to_varname(n::VarName) = n + +function DynamicPPL.marginalize( + model::DynamicPPL.Model, + 
varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + getlogprob=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + kwargs..., +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + vns = map(_to_varname, varnames) + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space + varinfo_linked = DynamicPPL.link(varinfo, model) + + f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) + mdl = MarginalLogDensities.MarginalLogDensity( + (x, _) -> LogDensityProblems.logdensity(f, x), + varinfo_linked[:], + varindices, + (), + method; + kwargs..., + ) + return mdl +end + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..f67cec6b7 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -122,6 +122,7 @@ export AbstractVarInfo, fix, unfix, predict, + marginalize, prefix, returned, to_submodel, @@ -199,10 +200,6 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -if !isdefined(Base, :get_extension) - using Requires -end - # Better error message if users forget to load JET if isdefined(Base.Experimental, :register_error_hint) function __init__() @@ -247,4 +244,7 @@ end # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ struct DynamicPPLTag end +# Extended in MarginalLogDensitiesExt +function marginalize end + end # module diff --git a/test/Project.toml b/test/Project.toml index 91a885e96..168f360b8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MarginalLogDensities = 
"f0c3360a-fb8d-11e9-1194-5521fd7ee392" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..381f67e97 --- /dev/null +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,27 @@ +module MarginalLogDensitiesExtTests + +using DynamicPPL, Distributions, Test +using MarginalLogDensities +using ADTypes: AutoForwardDiff + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + return y ~ Normal(0, 1) + end + model = demo() + # Marginalize out `x`. + + for vn in [@varname(x), :x] + for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] + marginalized = marginalize( + model, [vn], getlogprob; hess_adtype=AutoForwardDiff() + ) + # Compute the marginal log-density of `y = 0.0`. 
+ @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5 + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..40960884e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ include("test_util.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") + include("ext/DynamicPPLMarginalLogDensitiesExt.jl") end @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") From 8fab001e737c32bb1600e34af8013a09aae915f7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:48:14 +0100 Subject: [PATCH 02/16] Add documentation --- docs/Project.toml | 1 + docs/make.jl | 1 + docs/src/api.md | 9 +++++ ext/DynamicPPLMarginalLogDensitiesExt.jl | 45 ++++++++++++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..cc0be339d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] diff --git a/docs/make.jl b/docs/make.jl index 9c59cb06b..4185672e1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,7 @@ using Distributions using DocumenterMermaid # load MCMCChains package extension to make `predict` available using MCMCChains +using MarginalLogDensities: MarginalLogDensities # Doctest setup DocMeta.setdocmeta!( diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..d93c58863 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -123,6 +123,15 @@ The `predict` function has two main methods: predict ``` +## Marginalization + +DynamicPPL provides the `marginalize` function to marginalize out variables from a model. 
+This requires `MarginalLogDensities.jl` to be loaded in your environment. + +```@docs +marginalize +``` + ### Basic Usage The typical workflow for posterior prediction involves: diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 8153d1522..f79e7a027 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,6 +6,51 @@ using MarginalLogDensities: MarginalLogDensities _to_varname(n::Symbol) = VarName{n}() _to_varname(n::VarName) = n +""" + marginalize( + model::DynamicPPL.Model, + varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + getlogprob=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + kwargs..., + ) + +Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal +log-density of the given `model`, after marginalizing out the variables specified in +`varnames`. + +The resulting object can be called with a vector of parameter values to compute the marginal +log-density. + +The `getlogprob` argument can be used to specify which kind of marginal log-density to +compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint +probability. + +By default the marginalization is performed with a Laplace approximation. Please see [the +MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) +for other options. + +## Example + +```jldoctest +julia> using DynamicPPL, Distributions, MarginalLogDensities + +julia> @model function demo() + x ~ Normal(1.0) + y ~ Normal(2.0) + end +demo (generic function with 2 methods) + +julia> marginalized = marginalize(demo(), [:x]); + +julia> # The resulting callable computes the marginal log-density of `y`. 
+ marginalized([1.0]) +-1.4189385332046727 + +julia> logpdf(Normal(2.0), 1.0) +-1.4189385332046727 +``` +""" function DynamicPPL.marginalize( model::DynamicPPL.Model, varnames::AbstractVector{<:Union{Symbol,<:VarName}}, From 105013a63b9d2ff95b636ab79749fd7a151d5cdf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:54:28 +0100 Subject: [PATCH 03/16] Bump patch, add changelog --- HISTORY.md | 5 +++++ Project.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 0f22721d9..b2983a426 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.37.3 + +An extension for MarginalLogDensities.jl has been added. +Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model; please see the documentation for further information. + ## 0.37.2 Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well. diff --git a/Project.toml b/Project.toml index 56a681375..df764ffe2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.2" +version = "0.37.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 1b3e76b1c4e951223dcbc4f57372155352820fa2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 18:19:34 +0100 Subject: [PATCH 04/16] Add compat entry for MLD --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index df764ffe2..11a37316f 100644 --- a/Project.toml +++ b/Project.toml @@ -67,6 +67,7 @@ JET = "0.9, 0.10" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" +MarginalLogDensities = "0.4.1" MCMCChains = "6, 7" MacroTools = "0.5.6" Mooncake = "0.4.147" From aaca138c84dce3ac15acbd0653d5ebb5df5c153e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 18:25:06 +0100 Subject: [PATCH 05/16] Fix docs --- 
docs/make.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 4185672e1..de64a21f5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -25,7 +25,11 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() ), - modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], + modules=[ + DynamicPPL, + Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt), + Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt), + ], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] ], From 14f9f464153dafcf1edeab53bbb42c56410a47b8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 15 Sep 2025 16:37:10 +0100 Subject: [PATCH 06/16] Allow user to specify VarInfo used for marginalisation --- ext/DynamicPPLMarginalLogDensitiesExt.jl | 12 ++-- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 67 +++++++++++++++---- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index f79e7a027..75681b2bc 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -9,6 +9,7 @@ _to_varname(n::VarName) = n """ marginalize( model::DynamicPPL.Model, + varinfo::DynamicPPL.AbstractVarInfo=VarInfo(model), varnames::AbstractVector{<:Union{Symbol,<:VarName}}, getlogprob=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); @@ -54,22 +55,19 @@ julia> logpdf(Normal(2.0), 1.0) function DynamicPPL.marginalize( model::DynamicPPL.Model, varnames::AbstractVector{<:Union{Symbol,<:VarName}}, - getlogprob=DynamicPPL.getlogjoint, + varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model), + getlogprob::Function=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); kwargs..., ) # Determine 
the indices for the variables to marginalise out. - varinfo = DynamicPPL.typed_varinfo(model) vns = map(_to_varname, varnames) varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) # Construct the marginal log-density model. - # Use linked `varinfo` to that we're working in unconstrained space - varinfo_linked = DynamicPPL.link(varinfo, model) - - f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) + f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mdl = MarginalLogDensities.MarginalLogDensity( (x, _) -> LogDensityProblems.logdensity(f, x), - varinfo_linked[:], + varinfo[:], varindices, (), method; diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 381f67e97..5b91f2529 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,25 +1,66 @@ module MarginalLogDensitiesExtTests +using Bijectors: Bijectors using DynamicPPL, Distributions, Test using MarginalLogDensities using ADTypes: AutoForwardDiff @testset "MarginalLogDensities" begin - # Simple test case. - @model function demo() - x ~ MvNormal(zeros(2), [1, 1]) - return y ~ Normal(0, 1) + @testset "Basic usage" begin + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + return y ~ Normal(0, 1) + end + model = demo() + vi = VarInfo(model) + # Marginalize out `x`. + for vn in [@varname(x), :x] + for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] + marginalized = marginalize( + model, [vn], vi, getlogprob; hess_adtype=AutoForwardDiff() + ) + for y in range(-5, 5; length=100) + @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 + end + end + end end - model = demo() - # Marginalize out `x`. 
- for vn in [@varname(x), :x] - for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] - marginalized = marginalize( - model, [vn], getlogprob; hess_adtype=AutoForwardDiff() - ) - # Compute the marginal log-density of `y = 0.0`. - @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5 + @testset "Respects linked status of VarInfo" begin + @model function f() + x ~ Normal() + return y ~ Beta(2, 2) + end + model = f() + vi_unlinked = VarInfo(model) + vi_linked = DynamicPPL.link(vi_unlinked, model) + + @testset "unlinked VarInfo" begin + mx = marginalize(model, [@varname(x)], vi_unlinked) + for x in range(0.01, 0.99; length=10) + @test mx([x]) ≈ logpdf(Beta(2, 2), x) + end + # generally when marginalising Beta it doesn't go to zero + my = marginalize(model, [@varname(y)], vi_unlinked) + diff = my([0.0]) - logpdf(Normal(), 0.0) + for x in range(-5, 5; length=10) + @test my([x]) ≈ logpdf(Normal(), x) + diff + end + end + + @testset "linked VarInfo" begin + mx = marginalize(model, [@varname(x)], vi_linked) + binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) + for y_linked in range(-5, 5; length=10) + y_unlinked = binv(y_linked) + @test mx([y_linked]) ≈ logpdf(Beta(2, 2), y_unlinked) + end + # generally when marginalising Beta it doesn't go to zero + my = marginalize(model, [@varname(y)], vi_linked) + diff = my([0.0]) - logpdf(Normal(), 0.0) + for x in range(-5, 5; length=10) + @test my([x]) ≈ logpdf(Normal(), x) + diff + end end end end From 00c08b28de7681a483574ce33eda6e6da1d1a1a9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 16 Sep 2025 15:45:20 +0100 Subject: [PATCH 07/16] Use linked varinfo by default --- Project.toml | 6 ++---- ext/DynamicPPLMarginalLogDensitiesExt.jl | 8 ++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 11a37316f..b72114f61 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections 
= "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -42,8 +41,8 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] -DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMooncakeExt = ["Mooncake"] [compat] @@ -67,14 +66,13 @@ JET = "0.9, 0.10" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" -MarginalLogDensities = "0.4.1" MCMCChains = "6, 7" MacroTools = "0.5.6" +MarginalLogDensities = "0.4.1" Mooncake = "0.4.147" OrderedCollections = "1" Printf = "1.10" Random = "1.6" -Requires = "1" Statistics = "1" Test = "1.6" julia = "1.10.8" diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 75681b2bc..3d3aef360 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -9,8 +9,8 @@ _to_varname(n::VarName) = n """ marginalize( model::DynamicPPL.Model, - varinfo::DynamicPPL.AbstractVarInfo=VarInfo(model), varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), getlogprob=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); kwargs..., @@ -23,6 +23,10 @@ log-density of the given `model`, after marginalizing out the variables specifie The resulting object can be called with a vector of parameter values to compute the marginal log-density. +You can specify the `varinfo` to use for the model. 
By default we use a linked `VarInfo`, +meaning that the resulting log-density function accepts parameters that have been +transformed to unconstrained space. + The `getlogprob` argument can be used to specify which kind of marginal log-density to compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint probability. @@ -55,7 +59,7 @@ julia> logpdf(Normal(2.0), 1.0) function DynamicPPL.marginalize( model::DynamicPPL.Model, varnames::AbstractVector{<:Union{Symbol,<:VarName}}, - varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model), + varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), getlogprob::Function=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); kwargs..., From e9eabb469c2250713db8e977989830f1bb22d2f5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 17 Sep 2025 22:59:33 +0100 Subject: [PATCH 08/16] Make the non-essential stuff all keyword arguments --- ext/DynamicPPLMarginalLogDensitiesExt.jl | 29 +++++++++++-------- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 16 ++++++---- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 3d3aef360..9652efc03 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -9,7 +9,7 @@ _to_varname(n::VarName) = n """ marginalize( model::DynamicPPL.Model, - varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + varnames::AbstractVector{<:Union{Symbol,<:VarName}}; varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), getlogprob=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); @@ -23,17 +23,22 @@ log-density of the given `model`, after marginalizing out the variables specifie The resulting object can be called with a vector of parameter values to compute the marginal 
log-density. -You can specify the `varinfo` to use for the model. By default we use a linked `VarInfo`, -meaning that the resulting log-density function accepts parameters that have been -transformed to unconstrained space. +## Keyword arguments -The `getlogprob` argument can be used to specify which kind of marginal log-density to -compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint -probability. +- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`, + meaning that the resulting log-density function accepts parameters that have bee_FWDn + transformed to unconstrained space. -By default the marginalization is performed with a Laplace approximation. Please see [the -MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) -for other options. +- `getlogprob`: A function which specifies which kind of marginal log-density to compute. + Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint + probability. + +- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the + MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) + for other options. + +- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity` + constructor. ## Example @@ -58,10 +63,10 @@ julia> logpdf(Normal(2.0), 1.0) """ function DynamicPPL.marginalize( model::DynamicPPL.Model, - varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + varnames::AbstractVector{<:Union{Symbol,<:VarName}}; varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), getlogprob::Function=DynamicPPL.getlogjoint, - method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), kwargs..., ) # Determine the indices for the variables to marginalise out. 
diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 5b91f2529..dec23af29 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -17,7 +17,11 @@ using ADTypes: AutoForwardDiff for vn in [@varname(x), :x] for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] marginalized = marginalize( - model, [vn], vi, getlogprob; hess_adtype=AutoForwardDiff() + model, + [vn]; + varinfo=vi, + getlogprob=getlogprob, + hess_adtype=AutoForwardDiff(), ) for y in range(-5, 5; length=100) @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 @@ -36,12 +40,13 @@ using ADTypes: AutoForwardDiff vi_linked = DynamicPPL.link(vi_unlinked, model) @testset "unlinked VarInfo" begin - mx = marginalize(model, [@varname(x)], vi_unlinked) + mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) for x in range(0.01, 0.99; length=10) @test mx([x]) ≈ logpdf(Beta(2, 2), x) end # generally when marginalising Beta it doesn't go to zero - my = marginalize(model, [@varname(y)], vi_unlinked) + # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 + my = marginalize(model, [@varname(y)]; varinfo=vi_unlinked) diff = my([0.0]) - logpdf(Normal(), 0.0) for x in range(-5, 5; length=10) @test my([x]) ≈ logpdf(Normal(), x) + diff @@ -49,14 +54,15 @@ using ADTypes: AutoForwardDiff end @testset "linked VarInfo" begin - mx = marginalize(model, [@varname(x)], vi_linked) + mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) for y_linked in range(-5, 5; length=10) y_unlinked = binv(y_linked) @test mx([y_linked]) ≈ logpdf(Beta(2, 2), y_unlinked) end # generally when marginalising Beta it doesn't go to zero - my = marginalize(model, [@varname(y)], vi_linked) + # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 + my = marginalize(model, [@varname(y)]; varinfo=vi_linked) 
diff = my([0.0]) - logpdf(Normal(), 0.0) for x in range(-5, 5; length=10) @test my([x]) ≈ logpdf(Normal(), x) + diff From 844ec4c74feea3073dfccaa97552b3603b718ac2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 17 Sep 2025 23:00:15 +0100 Subject: [PATCH 09/16] Fix docs --- docs/src/api.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index d93c58863..e0bd5572f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -123,15 +123,6 @@ The `predict` function has two main methods: predict ``` -## Marginalization - -DynamicPPL provides the `marginalize` function to marginalize out variables from a model. -This requires `MarginalLogDensities.jl` to be loaded in your environment. - -```@docs -marginalize -``` - ### Basic Usage The typical workflow for posterior prediction involves: @@ -145,6 +136,15 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables a - `include_all=false` (default): Include only newly predicted variables - `include_all=true`: Include both parameters from the original chain and predicted variables +## Marginalization + +DynamicPPL provides the `marginalize` function to marginalize out variables from a model. +This requires `MarginalLogDensities.jl` to be loaded in your environment. + +```@docs +marginalize +``` + ## Models within models One can include models and call another model inside the model function with `left ~ to_submodel(model)`. 
From f4049c1685a532df92bdce0cef6ae77fdf943613 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 01:33:17 +0100 Subject: [PATCH 10/16] Add `VarInfo(::MarginalLogDensity)` method --- docs/make.jl | 3 + docs/src/api.md | 7 + ext/DynamicPPLMarginalLogDensitiesExt.jl | 137 ++++++++++++++++-- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 32 ++++ 4 files changed, 170 insertions(+), 9 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index de64a21f5..828b20658 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -13,6 +13,9 @@ using DocumenterMermaid using MCMCChains using MarginalLogDensities: MarginalLogDensities +# Need this to document a method which uses a type inside the extension... +DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt) + # Doctest setup DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true diff --git a/docs/src/api.md b/docs/src/api.md index e0bd5572f..db7a8d3b5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -145,6 +145,13 @@ This requires `MarginalLogDensities.jl` to be loaded in your environment. marginalize ``` +A `MarginalLogDensity` object acts as a function which maps non-marginalized parameter values to a marginal log-probability. +To retrieve a VarInfo object from it, you can use: + +```@docs +VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing}) +``` + ## Models within models One can include models and call another model inside the model function with `left ~ to_submodel(model)`. 
diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 9652efc03..4870b350b 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,6 +6,16 @@ using MarginalLogDensities: MarginalLogDensities _to_varname(n::Symbol) = VarName{n}() _to_varname(n::VarName) = n +# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by +# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type +# below. +struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} + logdensity::L +end +function (lw::LogDensityFunctionWrapper)(x, _) + return LogDensityProblems.logdensity(lw.logdensity, x) +end + """ marginalize( model::DynamicPPL.Model, @@ -26,7 +36,7 @@ log-density. ## Keyword arguments - `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`, - meaning that the resulting log-density function accepts parameters that have bee_FWDn + meaning that the resulting log-density function accepts parameters that have been transformed to unconstrained space. - `getlogprob`: A function which specifies which kind of marginal log-density to compute. @@ -60,6 +70,26 @@ julia> # The resulting callable computes the marginal log-density of `y`. julia> logpdf(Normal(2.0), 1.0) -1.4189385332046727 ``` + + +!!! warning + + The default usage of linked VarInfo means that, for example, optimization of the + marginal log-density can be performed in unconstrained space. However, care must be + taken if the model contains variables where the link transformation depends on a + marginalized variable. For example: + + ```julia + @model function f() + x ~ Normal() + y ~ truncated(Normal(); lower=x) + end + ``` + + Here, the support of `y`, and hence the link transformation used, depends on the value + of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of + `y` to log-probabilities. 
However, it will not be possible to use DynamicPPL to + correctly retrieve _unlinked_ values of `y`. """ function DynamicPPL.marginalize( model::DynamicPPL.Model, @@ -74,15 +104,104 @@ function DynamicPPL.marginalize( varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) - mdl = MarginalLogDensities.MarginalLogDensity( - (x, _) -> LogDensityProblems.logdensity(f, x), - varinfo[:], - varindices, - (), - method; - kwargs..., + mld = MarginalLogDensities.MarginalLogDensity( + LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + ) + return mld +end + +""" + VarInfo( + mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + unmarginalized_params::Union{AbstractVector,Nothing}=nothing ) - return mdl + +Retrieve the `VarInfo` object used in the marginalisation process. + +If a Laplace approximation was used for the marginalisation, the values of the marginalized +parameters are also set to their mode (note that this only happens if the `mld` object has +been used to compute the marginal log-density at least once, so that the mode has been +computed). + +If a vector of `unmarginalized_params` is specified, the values for the corresponding +parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by +performing an optimization of the marginal log-density. + +All other aspects of the VarInfo, such as link status, are preserved from the original +VarInfo used in the marginalisation. + +!!! note + + The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be + updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the + model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model, + vi))`). 
+ +## Example + +```jldoctest +julia> using DynamicPPL, Distributions, MarginalLogDensities + +julia> @model function demo() + x ~ Normal() + y ~ Beta(2, 2) + end +demo (generic function with 2 methods) + +julia> # Note that by default `marginalize` uses a linked VarInfo. + mld = marginalize(demo(), [@varname(x)]); + +julia> using MarginalLogDensities: Optimization, OptimizationOptimJL + +julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. + y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) +OptimizationProblem. In-place: true +u0: 1-element Vector{Float64}: + 2.0 + +julia> # This tells us the optimal (linked) value of `y` is around 0. + opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) +retcode: Success +u: 1-element Vector{Float64}: + 4.88281250001733e-5 + +julia> # Get the VarInfo corresponding to the mode of `y`. + vi = VarInfo(mld, opt_solution.u); + +julia> # `x` is set to its mode (which for `Normal()` is zero). + vi[@varname(x)] +0.0 + +julia> # `y` is set to the optimal value we found above. + DynamicPPL.getindex_internal(vi, @varname(y)) +1-element Vector{Float64}: + 4.88281250001733e-5 + +julia> # To obtain values in the original constrained space, we can either + # use `getindex`: + vi[@varname(y)] +0.5000122070312476 + +julia> # Or invlink the entire VarInfo object using the model: + vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] +2-element Vector{Float64}: + 0.0 + 0.5000122070312476 +``` +""" +function DynamicPPL.VarInfo( + mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + unmarginalized_params::Union{AbstractVector,Nothing}=nothing, +) + # Extract the original VarInfo. Its contents will in general be junk. 
+ original_vi = mld.logdensity.logdensity.varinfo + # `mld.u` will contain the modes for any marginalized parameters + full_params = mld.u + # We can then set the values for any non-marginalized parameters + if unmarginalized_params !== nothing + full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params + end + return DynamicPPL.unflatten(original_vi, full_params) end end diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index dec23af29..4822ec8e6 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -69,6 +69,38 @@ using ADTypes: AutoForwardDiff end end end + + @testset "retrieving VarInfo from MLD" begin + @model function f() + x ~ Normal() + return y ~ Beta(2, 2) + end + model = f() + vi_unlinked = VarInfo(model) + vi_linked = DynamicPPL.link(vi_unlinked, model) + + @testset "unlinked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) + mx([0.5]) # evaluate at some point to force calculation of Laplace approx + vi = VarInfo(mx) + @test vi[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked + @test vi[@varname(x)] ≈ mode(Normal()) + @test vi[@varname(y)] ≈ 0.5 + end + + @testset "linked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) + mx([0.5]) # evaluate at some point to force calculation of Laplace approx + vi = VarInfo(mx) + @test vi[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is linked + binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) + @test vi[@varname(x)] ≈ mode(Normal()) + # when using getindex it always returns unlinked values + @test vi[@varname(y)] ≈ binv(0.5) + end + end end end From 144bee7c3ab695045aa9081ede9dbba5d1cfcce9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 12:29:14 +0100 Subject: [PATCH 11/16] Use new `cached_params` function --- Project.toml | 2 +- 
ext/DynamicPPLMarginalLogDensitiesExt.jl | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index b72114f61..488906530 100644 --- a/Project.toml +++ b/Project.toml @@ -68,7 +68,7 @@ LinearAlgebra = "1.6" LogDensityProblems = "2" MCMCChains = "6, 7" MacroTools = "0.5.6" -MarginalLogDensities = "0.4.1" +MarginalLogDensities = "0.4.3" Mooncake = "0.4.147" OrderedCollections = "1" Printf = "1.10" diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 4870b350b..38dd67167 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -195,9 +195,10 @@ function DynamicPPL.VarInfo( ) # Extract the original VarInfo. Its contents will in general be junk. original_vi = mld.logdensity.logdensity.varinfo - # `mld.u` will contain the modes for any marginalized parameters - full_params = mld.u - # We can then set the values for any non-marginalized parameters + # Extract the stored parameters, which includes the modes for any marginalized + # parameters + full_params = MarginalLogDensities.cached_params(mld) + # We can then (if needed) set the values for any non-marginalized parameters if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end From 29840be18dc760c43fe7b0a59113bddf7906b852 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 12:31:15 +0100 Subject: [PATCH 12/16] Add more detailed changelog Co-authored-by: Sam Urmy --- HISTORY.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index b2983a426..71b35e749 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -3,7 +3,12 @@ ## 0.37.3 An extension for MarginalLogDensities.jl has been added. 
-Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model; please see the documentation for further information. + +Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model. +This is useful for averaging out random effects or nuisance parameters while improving inference on fixed effects/parameters of interest. +The `marginalize` function returns a `MarginalLogDensities.MarginalLogDensity`, a function-like callable struct that returns the approximate log-density of a subset of the parameters after integrating out the rest of them. +By default, this uses the Laplace approximation and sparse AD, making the marginalization computationally very efficient. +Please see [the documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalization) for further information. ## 0.37.2 From 06599cf8920f416eab77271768eaa1eefc4e4434 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 19 Sep 2025 10:34:44 +0100 Subject: [PATCH 13/16] Add error hint if marginalize is called before loading MLD --- src/DynamicPPL.jl | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f67cec6b7..bdc953a12 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -200,9 +200,9 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -# Better error message if users forget to load JET if isdefined(Base.Experimental, :register_error_hint) function __init__() + # Better error message if users forget to load JET.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_jet = exc.f === DynamicPPL.Experimental._determine_varinfo_jet && @@ -223,6 +223,23 @@ if isdefined(Base.Experimental, :register_error_hint) end end + # Same for MarginalLogDensities.jl + Base.Experimental.register_error_hint(MethodError) do io, 
exc, argtypes, _ + requires_mld = + exc.f === DynamicPPL.marginalize && + length(argtypes) == 2 && + argtypes[1] <: Model && + argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}} + if requires_mld + printstyled( + io, + "\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n"; + color=:cyan, + bold=true, + ) + end + end + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ is_evaluate_three_arg = exc.f === AbstractPPL.evaluate!! && From 015534bc8d0be02fef17f11cc9df0f5f7c3262f7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 19 Sep 2025 10:43:24 +0100 Subject: [PATCH 14/16] fix comma -> semicolon typo --- ext/DynamicPPLMarginalLogDensitiesExt.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 38dd67167..2155fa161 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -3,9 +3,6 @@ module DynamicPPLMarginalLogDensitiesExt using DynamicPPL: DynamicPPL, LogDensityProblems, VarName using MarginalLogDensities: MarginalLogDensities -_to_varname(n::Symbol) = VarName{n}() -_to_varname(n::VarName) = n - # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. 
@@ -19,7 +16,7 @@ end """ marginalize( model::DynamicPPL.Model, - varnames::AbstractVector{<:Union{Symbol,<:VarName}}; + marginalized_varnames::AbstractVector{<:VarName}; varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), getlogprob=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); @@ -93,15 +90,14 @@ julia> logpdf(Normal(2.0), 1.0) """ function DynamicPPL.marginalize( model::DynamicPPL.Model, - varnames::AbstractVector{<:Union{Symbol,<:VarName}}; + marginalized_varnames::AbstractVector{<:VarName}; varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), getlogprob::Function=DynamicPPL.getlogjoint, method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), kwargs..., ) # Determine the indices for the variables to marginalise out. - vns = map(_to_varname, varnames) - varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames)) # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( From 206d6610ce936ada28d52c53de104e8596ce39ae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 19 Sep 2025 11:05:56 +0100 Subject: [PATCH 15/16] remove test with symbol --- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 4822ec8e6..32c4bb479 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -14,18 +14,16 @@ using ADTypes: AutoForwardDiff model = demo() vi = VarInfo(model) # Marginalize out `x`. 
- for vn in [@varname(x), :x] - for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] - marginalized = marginalize( - model, - [vn]; - varinfo=vi, - getlogprob=getlogprob, - hess_adtype=AutoForwardDiff(), - ) - for y in range(-5, 5; length=100) - @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 - end + @testset for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] + marginalized = marginalize( + model, + [@varname(x)]; + varinfo=vi, + getlogprob=getlogprob, + hess_adtype=AutoForwardDiff(), + ) + for y in range(-5, 5; length=100) + @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 end end end From 2b1d5e3ed966493a62194de8a6927f146e99035c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Sep 2025 13:31:57 +0100 Subject: [PATCH 16/16] Update changelog, use -ise in prose --- HISTORY.md | 8 +++++--- docs/src/api.md | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0c416b277..9c85674c3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,11 +4,13 @@ An extension for MarginalLogDensities.jl has been added. -Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model. +Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalise out variables from a model. This is useful for averaging out random effects or nuisance parameters while improving inference on fixed effects/parameters of interest. The `marginalize` function returns a `MarginalLogDensities.MarginalLogDensity`, a function-like callable struct that returns the approximate log-density of a subset of the parameters after integrating out the rest of them. -By default, this uses the Laplace approximation and sparse AD, making the marginalization computationally very efficient. 
-Please see [the documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalization) for further information. +By default, this uses the Laplace approximation and sparse AD, making the marginalisation computationally very efficient. +Note that the Laplace approximation relies on the model being differentiable with respect to the marginalised variables, and on their posteriors being unimodal and approximately Gaussian. + +Please see [the MarginalLogDensities documentation](https://eloceanografo.github.io/MarginalLogDensities.jl/stable) and the [new Marginalisation section of the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalisation) for further information. ## 0.37.3 diff --git a/docs/src/api.md b/docs/src/api.md index db7a8d3b5..999bbe822 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -136,16 +136,16 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables a - `include_all=false` (default): Include only newly predicted variables - `include_all=true`: Include both parameters from the original chain and predicted variables -## Marginalization +## Marginalisation -DynamicPPL provides the `marginalize` function to marginalize out variables from a model. +DynamicPPL provides the `marginalize` function to marginalise out variables from a model. This requires `MarginalLogDensities.jl` to be loaded in your environment. ```@docs marginalize ``` -A `MarginalLogDensity` object acts as a function which maps non-marginalized parameter values to a marginal log-probability. +A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability. To retrieve a VarInfo object from it, you can use: ```@docs