Extension for MarginalLogDensities.jl, take 2 #2664
Changes from 2 commits
New file (`@@ -0,0 +1,46 @@`):

```julia
module TuringMarginalLogDensitiesExt

using Turing: Turing, DynamicPPL
using Turing.Inference: LogDensityProblems
using MarginalLogDensities: MarginalLogDensities

# Use a struct for this to avoid closure overhead.
struct Drop2ndArgAndFlipSign{F}
    f::F
end

(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x)

_to_varname(n::Symbol) = DynamicPPL.@varname($n)
_to_varname(n::DynamicPPL.AbstractPPL.VarName) = n

function Turing.marginalize(
    model::DynamicPPL.Model,
    varnames::Vector,
    method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
)
    # Determine the indices for the variables to marginalise out.
    varinfo = DynamicPPL.typed_varinfo(model)
    vns = _to_varname.(varnames)
    varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
    # Construct the marginal log-density model.
    # Use linked `varinfo` so that we're working in unconstrained space.
    varinfo_linked = DynamicPPL.link(varinfo, model)

    f = Turing.Optimisation.OptimLogDensity(
        model,
        Turing.DynamicPPL.getlogjoint,
        # Turing.DynamicPPL.typed_varinfo(model)
        varinfo_linked,
    )
```
Inline review thread on the `OptimLogDensity` call:

> Just making a note here that I am pretty sure that there's a double negative sign somewhere here with `getlogjoint`.

> I think it's still needed:
>
> ```julia
> @model function demo()
>     x ~ Normal()
> end
> model = demo()
> f = Turing.Optimisation.OptimLogDensity(
>     model,
>     DynamicPPL.getlogjoint,
>     DynamicPPL.typed_varinfo(model)
> )
> f([0])              # 0.9189385332046728
> logpdf(Normal(), 0) # -0.9189385332046728
> ```

> Ah, but that's the fault of `OptimLogDensity`, not `getlogjoint` :)

> Not asking you to fix it btw - happy to do it when I'm back to writing code!
```julia
    # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which
    # represents the _negative_ log-density.
    mdl = MarginalLogDensities.MarginalLogDensity(
        Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method
    )
    return mdl
end
```
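MarginalLogDensities expects a callable of the form `f(u, data)` that returns the positive log-density, while `OptimLogDensity` evaluates the negative log joint; the wrapper above bridges the two. Here is a self-contained sketch of that adapter pattern (illustrative names only, no Turing dependency):

```julia
# Illustrative stand-in for the wrapper used above: drop the unused
# `data` argument and flip the sign of a negative log-density.
struct DropDataAndNegate{F}
    f::F
end

(w::DropDataAndNegate)(x, _) = -w.f(x)

# Negative log-density of a standard normal, up to the normalizing constant.
neglogp(x) = 0.5 * sum(abs2, x)

w = DropDataAndNegate(neglogp)
w([1.0, 2.0], nothing)  # == -2.5
```

Using a callable struct rather than an anonymous closure keeps the wrapped function's type in the struct's type parameter, which is the closure-overhead point noted in the code comment.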
Review thread:

> From reading the MLD readme, I gather that you can do a couple of things with this: […]
>
> Would it also be useful to be able to perform MCMC sampling with this? In principle it's not too difficult: a function that takes a vector of parameters and returns a float is pretty much what NUTS needs. To do it properly will probably take a bit of time: I think we would have to either change the interface of […]

> This should already work in theory, since MLD objects implement the LogDensityProblems interface (here). Would be good to add a test for it though.

> Never mind, I thought it was "just working" but it isn't. Shouldn't be too hard to fix with some kind of simple wrapper though.

> Yuppp, the samplers right now only work with `DynamicPPL.Model`. We should broaden those types and make an interface, but it's quite a large undertaking. Do you know if AD would be able to differentiate the marginal log-density?

> Not yet, it's the last major feature I need to add to MLD. It's a bit tricky since you need to differentiate through the Hessian of an optimization result, but it should be possible.
```julia

end
```
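The "simple wrapper" floated in the thread above could be a small LogDensityProblems shim. This is a hypothetical sketch, not part of the PR: it assumes a marginal log-density object `mdl` that is callable on a vector of the non-marginalized parameters, and the struct name is invented for illustration:

```julia
using LogDensityProblems

# Hypothetical shim exposing a callable marginal log-density through
# the LogDensityProblems interface, so samplers that consume
# LogDensityProblems objects (e.g. NUTS) could work with it.
struct MarginalLogDensityProblem{M}
    mdl::M
    dim::Int
end

LogDensityProblems.dimension(p::MarginalLogDensityProblem) = p.dim
LogDensityProblems.logdensity(p::MarginalLogDensityProblem, θ) = p.mdl(θ)
function LogDensityProblems.capabilities(::Type{<:MarginalLogDensityProblem})
    # Order 0: no gradient yet, matching the AD discussion above.
    return LogDensityProblems.LogDensityOrder{0}()
end
```

Until MLD supports differentiating through the inner optimization's Hessian, such a wrapper would only advertise zeroth-order capability, restricting it to gradient-free samplers.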
New file (`@@ -0,0 +1,3 @@`):

```julia
function marginalize(model, varnames, method)
    error("This function is available after importing MarginalLogDensities.")
end
```
New file (`@@ -0,0 +1,19 @@`):

```julia
module TuringMarginalLogDensitiesTest

using Turing, MarginalLogDensities, Test

@testset "MarginalLogDensities" begin
    # Simple test case.
    @model function demo()
        x ~ MvNormal(zeros(2), [1, 1])
        y ~ Normal(0, 1)
    end
    model = demo();
    # Marginalize out `x`.
    marginalized = marginalize(model, [@varname(x)]);
    marginalized = marginalize(model, [:x]);
    # Compute the marginal log-density of `y = 0.0`.
    @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=2e-1
end

end
```
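The expected value in this test has a closed form: `x` and `y` are independent in `demo()`, so integrating `x` out of the joint leaves exactly `Normal(0, 1)` for `y`, and the Laplace approximation is exact for Gaussian targets. A quick check of the identity being asserted (assuming only Distributions.jl):

```julia
using Distributions

# Since x ⊥ y, ∫ p(x, y) dx = p(y), so the marginal log-density of
# y = 0 is logpdf(Normal(0, 1), 0) = -log(2π)/2 ≈ -0.9189.
exact = logpdf(Normal(0, 1), 0.0)
exact ≈ -0.5 * log(2π)  # true
```

Given exactness here, the loose `atol=2e-1` presumably just absorbs numerical error from the inner optimization rather than approximation error.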