5 changes: 4 additions & 1 deletion Project.toml
@@ -41,10 +41,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
TuringDynamicHMCExt = "DynamicHMC"
TuringMarginalLogDensitiesExt = "MarginalLogDensities"
TuringOptimExt = "Optim"

[compat]
@@ -71,6 +73,7 @@ Libtask = "0.9.3"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
MarginalLogDensities = "0.4.1"
NamedArrays = "0.9, 0.10"
Optim = "1"
Optimization = "3, 4"
@@ -89,4 +92,4 @@ julia = "1.10.8"

[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
46 changes: 46 additions & 0 deletions ext/TuringMarginalLogDensitiesExt.jl
@@ -0,0 +1,46 @@
module TuringMarginalLogDensitiesExt

using Turing: Turing, DynamicPPL
using Turing.Inference: LogDensityProblems
using MarginalLogDensities: MarginalLogDensities


# Use a struct for this to avoid closure overhead.
struct Drop2ndArgAndFlipSign{F}
f::F
end

(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x)

_to_varname(n::Symbol) = DynamicPPL.@varname($n)
_to_varname(n::DynamicPPL.AbstractPPL.VarName) = n

function Turing.marginalize(
model::DynamicPPL.Model,
varnames::Vector,
method::MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities.LaplaceApprox(),
)
# Determine the indices for the variables to marginalise out.
varinfo = DynamicPPL.typed_varinfo(model)
vns = _to_varname.(varnames)
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
# Construct the marginal log-density model.
    # Use linked `varinfo` so that we're working in unconstrained space
varinfo_linked = DynamicPPL.link(varinfo, model)

f = Turing.Optimisation.OptimLogDensity(
model,
Turing.DynamicPPL.getlogjoint,
# Turing.DynamicPPL.typed_varinfo(model)
varinfo_linked
)
Member:
Just making a note here that I am pretty sure that there's a double negative sign somewhere here with OptimLogDensity and the FlipSign thing, so I think we will be able to simplify this. (Pre-DynamicPPL 0.37, it used to be that OptimLogDensity was the only way to get the log-joint without the Jacobian term, so I'm guessing that's why it got used here, but now we can do that more easily with getlogjoint.)

Contributor Author:
I think it's still needed, getlogjoint appears to return a negative log-density:

@model function demo()
    x ~ Normal()
end
model = demo()
f = Turing.Optimisation.OptimLogDensity(
    model,
    DynamicPPL.getlogjoint,
    DynamicPPL.typed_varinfo(model)
)
f([0]) # 0.9189385332046728
logpdf(Normal(), 0) # -0.9189385332046728

Member:
Ah, but that's the fault of OptimLogDensity not getlogjoint :)

Member:
Not asking you to fix it btw- happy to do it when I'm back to writing code!
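The double negative being discussed can be demonstrated without Turing at all. A minimal sketch (assuming a standard-normal target; `neg_logdensity` is a stand-in for what `OptimLogDensity` returns, not Turing's API):

```julia
# `OptimLogDensity` is built for minimizers, so it returns the *negative*
# log-density. The wrapper drops MLD's second (data) argument and flips the sign.
struct Drop2ndArgAndFlipSign{F}
    f::F
end
(w::Drop2ndArgAndFlipSign)(x, _) = -w.f(x)

# Stand-in for OptimLogDensity on a standard normal: negative log-density at x.
neg_logdensity(x) = 0.5 * sum(abs2, x) + 0.5 * length(x) * log(2π)

g = Drop2ndArgAndFlipSign(neg_logdensity)
g([0.0], nothing)  # ≈ -0.9189, i.e. logpdf(Normal(), 0)
```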


    # HACK: the sign-flip is needed because `OptimizationContext` is a hacky
    # implementation that represents the _negative_ log-density.
mdl = MarginalLogDensities.MarginalLogDensity(
Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method
)
return mdl
end
Member:
From reading the MLD readme, I gather that you can do a couple of things with this:

  1. Evaluate the marginal log density by calling it
  2. Perform optimisation to get point estimates

Would it also be useful to be able to perform MCMC sampling with this? In principle it's not too difficult: a function that takes a vector of parameters, and returns a float, is pretty much what NUTS needs.

To do it properly will probably take a bit of time: I think we would have to either change the interface of DynamicPPL.LogDensityFunction or (perhaps easier) do a wrapper around it that keeps track of which variables are marginalised out. So it wouldn't have to be part of this PR. But if this is worth doing then I could have a think about how to do it.
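As a sketch of that point (a hypothetical `MarginalTarget` wrapper, not an existing Turing or MLD type): all NUTS fundamentally needs is a callable from a parameter vector to a log-density value; to plug into the existing machinery one would additionally define the `LogDensityProblems` interface methods (`dimension`, `logdensity`, `capabilities`) on the wrapper.

```julia
# Hypothetical wrapper pairing a marginal log-density function with its dimension.
struct MarginalTarget{F}
    logdensity::F
    dim::Int
end
(t::MarginalTarget)(θ) = t.logdensity(θ)

# e.g. the marginal of y alone, with x integrated out analytically:
target = MarginalTarget(θ -> -0.5 * sum(abs2, θ) - 0.5 * length(θ) * log(2π), 1)
target([0.0])  # ≈ -0.9189
```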

Contributor Author:
This should already work in theory, since MLD objects implement the LogDensityProblems interface (here). Would be good to add a test for it though.

Contributor Author:
Never mind, I thought it was "just working" but it isn't. Shouldn't be too hard to do with some kind of simple wrapper, though.

Member:
Yuppp, the samplers right now only work with DynamicPPL.Model. We should broaden those types and make an interface but it's quite a large undertaking.

Do you know if AD would be able to differentiate the marginal log-density?

Contributor Author:
Not yet, it's the last major feature I need to add to MLD. It's a bit tricky since you need to differentiate through the Hessian of an optimization result, but it should be possible.
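For reference, the quantity `LaplaceApprox` is built around can be written down in a few lines (a sketch for a scalar marginalized variable; `laplace_marginal` is a hypothetical helper, not MLD's API). The marginal is approximated by the joint at the conditional mode plus a Gaussian-integral correction from the curvature; for a joint that is Gaussian in the marginalized variable, the approximation is exact:

```julia
# Joint log-density of independent x ~ Normal(0, 1), y ~ Normal(0, 1).
logjoint(x, y) = -0.5 * (x^2 + y^2) - log(2π)

# Laplace approximation to log ∫ exp(logjoint(x, y)) dx for scalar x:
# evaluate at the mode x̂ and correct with h = -∂²logjoint/∂x² at x̂.
function laplace_marginal(logjoint_x, x̂, h)
    return logjoint_x(x̂) + 0.5 * log(2π) - 0.5 * log(h)
end

lp = laplace_marginal(x -> logjoint(x, 0.0), 0.0, 1.0)
# Exact here: equals logpdf(Normal(0, 1), 0) = -0.5 * log(2π) ≈ -0.9189
```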


end
6 changes: 5 additions & 1 deletion src/Turing.jl
@@ -53,6 +53,8 @@ using .Variational
include("optimisation/Optimisation.jl")
using .Optimisation

include("extensions.jl")

###########
# Exports #
###########
@@ -153,6 +155,8 @@ export
maximum_a_posteriori,
maximum_likelihood,
MAP,
MLE
MLE,
# MarginalLogDensities extension
marginalize

end
3 changes: 3 additions & 0 deletions src/extensions.jl
@@ -0,0 +1,3 @@
function marginalize(model, varnames, method)
error("This function is available after importing MarginalLogDensities.")
end
Member:
I looked through MLD's and Turing's deps, and the only new dep would be HCubature which is a pretty small dep. So I might be in favour to just making this part of the source code itself instead of an extension. (I think Julia extensions are a really good thing in general but there are a few downsides to them, like having to do this kind of 'define in main package and extend in extension' stuff.)
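The 'define in main package and extend in extension' pattern mentioned here can be sketched as follows (a standalone toy module, not Turing's actual layout):

```julia
# Stub in the main package: throws until an extension supplies a real method.
module MainPkg
function marginalize(model, varnames, method)
    error("This function is available after importing MarginalLogDensities.")
end
end

# The extension (loaded automatically once the weak dependency is imported)
# would then add a working method, e.g.:
#     function MainPkg.marginalize(model::DynamicPPL.Model, varnames::Vector, method)
#         ...
#     end

# Without the extension, calling the stub throws an ErrorException:
try
    MainPkg.marginalize(nothing, Symbol[], nothing)
catch err
    err isa ErrorException  # true
end
```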

Contributor Author:
Sure, this one's up to you guys. I don't mind either way.

Member:
Alright, I'm back at work. I started trying to move some code, then I realised that nothing in this actually needs Turing. So I'm fairly certain that this should be in DynamicPPL (as an extension) rather than in Turing. I'm going to make a PR to DynamicPPL with the code shuffled around (and with the log-density sign thing fixed), will ping you over there :)

19 changes: 19 additions & 0 deletions test/ext/marginallogdensities.jl
@@ -0,0 +1,19 @@
module TuringMarginalLogDensitiesTest

using Turing, MarginalLogDensities, Test

@testset "MarginalLogDensities" begin
# Simple test case.
@model function demo()
x ~ MvNormal(zeros(2), [1, 1])
y ~ Normal(0, 1)
end
model = demo();
# Marginalize out `x`.
marginalized = marginalize(model, [@varname(x)]);
marginalized = marginalize(model, [:x]);
# Compute the marginal log-density of `y = 0.0`.
@test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=2e-1
end

end
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -69,6 +69,11 @@ end
@timeit_include("optimisation/Optimisation.jl")
@timeit_include("ext/OptimInterface.jl")
end

@testset "marginalization" verbose = true begin
@timeit_include("ext/marginallogdensities.jl")
end

end

@testset "stdlib" verbose = true begin