Skip to content

Commit 6d5dc40

Browse files
committed
Moved LogDensityFunction from Turing to DPPL (#447)
Related to TuringLang/Turing.jl#1936
1 parent a5e9efb commit 6d5dc40

16 files changed

+217
-54
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.21.4"
3+
version = "0.21.5"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1516
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1617
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -28,6 +29,7 @@ ChainRulesCore = "0.9.7, 0.10, 1"
2829
ConstructionBase = "1"
2930
Distributions = "0.23.8, 0.24, 0.25"
3031
DocStringExtensions = "0.8, 0.9"
32+
LogDensityProblems = "2"
3133
MacroTools = "0.5.6"
3234
OrderedCollections = "1"
3335
Setfield = "0.7.1, 0.8, 1"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
6+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
67
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
78
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
89
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -12,6 +13,7 @@ DataStructures = "0.18"
1213
Distributions = "0.25"
1314
Documenter = "0.27"
1415
FillArrays = "0.13"
16+
LogDensityProblems = "2"
1517
MLUtils = "0.3, 0.4"
1618
Setfield = "0.7.1, 0.8, 1"
1719
StableRNGs = "1"

docs/src/api.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ loglikelihood
5858
logjoint
5959
```
6060

61+
### LogDensityProblems.jl interface
62+
63+
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
64+
65+
```@docs
66+
DynamicPPL.LogDensityFunction
67+
```
68+
6169
## Condition and decondition
6270

6371
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -133,6 +141,9 @@ Finally, the following methods can also be of use:
133141
```@docs
134142
DynamicPPL.TestUtils.varnames
135143
DynamicPPL.TestUtils.posterior_mean
144+
DynamicPPL.TestUtils.setup_varinfos
145+
DynamicPPL.TestUtils.update_values!!
146+
DynamicPPL.TestUtils.test_values
136147
```
137148

138149
## Advanced

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using MacroTools: MacroTools
1313
using ConstructionBase: ConstructionBase
1414
using Setfield: Setfield
1515
using ZygoteRules: ZygoteRules
16+
using LogDensityProblems: LogDensityProblems
1617

1718
using DocStringExtensions
1819

@@ -163,5 +164,6 @@ include("loglikelihoods.jl")
163164
include("submodel_macro.jl")
164165
include("test_utils.jl")
165166
include("transforming.jl")
167+
include("logdensityfunction.jl")
166168

167169
end # module

src/abstract_varinfo.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,14 +519,25 @@ end
519519

520520
# Utilities
521521
"""
522-
unflatten(vi::AbstractVarInfo[, spl::AbstractSampler], x::AbstractVector)
522+
unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector)
523523
524524
Return a new instance of `vi` with the values of `x` assigned to the variables.
525525
526-
If `spl` is provided, `x` is assumed to be realizations only for variables related
527-
to `spl`.
526+
If `context` is provided, `x` is assumed to be realizations only for variables not
527+
filtered out by `context`.
528528
"""
529-
function unflatten end
529+
function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ)
530+
if hassampler(context)
531+
unflatten(getsampler(context), varinfo, context, θ)
532+
else
533+
DynamicPPL.unflatten(varinfo, θ)
534+
end
535+
end
536+
537+
# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables.
538+
function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ)
539+
return unflatten(varinfo, sampler, θ)
540+
end
530541

531542
"""
532543
tonamedtuple(vi::AbstractVarInfo)

src/contexts.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,28 @@ function setchildcontext(parent::SamplingContext, child)
163163
return SamplingContext(parent.rng, parent.sampler, child)
164164
end
165165

166+
"""
167+
hassampler(context)
168+
169+
Return `true` if `context` has a sampler.
170+
"""
171+
hassampler(::SamplingContext) = true
172+
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
173+
hassampler(::IsLeaf, context::AbstractContext) = false
174+
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
175+
176+
"""
177+
getsampler(context)
178+
179+
Return the sampler of the context `context`.
180+
181+
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
182+
at which point it will return the sampler of that context.
183+
"""
184+
getsampler(context::SamplingContext) = context.sampler
185+
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
186+
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
187+
166188
"""
167189
struct DefaultContext <: AbstractContext end
168190

src/logdensityfunction.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
LogDensityFunction
3+
4+
A callable representing a log density function of a `model`.
5+
6+
# Fields
7+
$(FIELDS)
8+
9+
# Examples
10+
```jldoctest
11+
julia> using Distributions
12+
13+
julia> using DynamicPPL: LogDensityFunction
14+
15+
julia> @model function demo(x)
16+
m ~ Normal()
17+
x ~ Normal(m, 1)
18+
end
19+
demo (generic function with 2 methods)
20+
21+
julia> model = demo(1.0);
22+
23+
julia> f = LogDensityFunction(model);
24+
25+
julia> # It implements the interface of LogDensityProblems.jl.
26+
using LogDensityProblems
27+
28+
julia> LogDensityProblems.logdensity(f, [0.0])
29+
-2.3378770664093453
30+
31+
julia> LogDensityProblems.dimension(f)
32+
1
33+
34+
julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
35+
f = LogDensityFunction(model, SimpleVarInfo(model));
36+
37+
julia> LogDensityProblems.logdensity(f, [0.0])
38+
-2.3378770664093453
39+
```
40+
"""
41+
struct LogDensityFunction{V,M,C}
42+
"varinfo used for evaluation"
43+
varinfo::V
44+
"model used for evaluation"
45+
model::M
46+
"context used for evaluation"
47+
context::C
48+
end
49+
50+
# TODO: Deprecate.
51+
function LogDensityFunction(
52+
varinfo::AbstractVarInfo,
53+
model::Model,
54+
sampler::AbstractSampler,
55+
context::AbstractContext,
56+
)
57+
return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
58+
end
59+
60+
function LogDensityFunction(
61+
model::Model,
62+
varinfo::AbstractVarInfo=VarInfo(model),
63+
context::AbstractContext=DefaultContext(),
64+
)
65+
return LogDensityFunction(varinfo, model, context)
66+
end
67+
68+
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
69+
# we need to define these annoying methods to ensure that we stay compatible with everything.
70+
getsampler(f::LogDensityFunction) = getsampler(f.context)
71+
hassampler(f::LogDensityFunction) = hassampler(f.context)
72+
73+
_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx)
74+
_get_indexer(ctx::SamplingContext) = ctx.sampler
75+
_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx))
76+
_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon()
77+
78+
"""
79+
getparams(f::LogDensityFunction)
80+
81+
Return the parameters of the wrapped varinfo as a vector.
82+
"""
83+
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)]
84+
85+
# LogDensityProblems interface
86+
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
87+
vi_new = unflatten(f.varinfo, f.context, θ)
88+
return getlogp(last(evaluate!!(f.model, vi_new, f.context)))
89+
end
90+
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
91+
return LogDensityProblems.LogDensityOrder{0}()
92+
end
93+
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
94+
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ function SimpleVarInfo{T}(
248248
return SimpleVarInfo(values, convert(T, getlogp(vi)))
249249
end
250250

251-
unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x)
251+
unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x)
252252
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
253253
return Setfield.@set svi.values = unflatten(svi.values, x)
254254
end

src/test_utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,50 @@ function varname_leaves(vn::VarName, val::AbstractArray)
3131
)
3232
end
3333

34+
"""
35+
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
36+
37+
Return instance similar to `vi` but with `vns` set to values from `vals`.
38+
"""
39+
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
40+
for vn in vns
41+
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
42+
end
43+
return vi
44+
end
45+
46+
"""
47+
test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
48+
49+
Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
50+
"""
51+
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal, kwargs...)
52+
for vn in vns
53+
@test isequal(vi[vn], get(vals, vn); kwargs...)
54+
end
55+
end
56+
57+
"""
58+
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
59+
60+
Return a tuple of instances for different implementations of `AbstractVarInfo` with
61+
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
62+
"""
63+
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
64+
# <:VarInfo
65+
vi_untyped = VarInfo()
66+
model(vi_untyped)
67+
vi_typed = DynamicPPL.TypedVarInfo(vi_untyped)
68+
# <:SimpleVarInfo
69+
svi_typed = SimpleVarInfo(example_values)
70+
svi_untyped = SimpleVarInfo(OrderedDict())
71+
72+
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
73+
# Set them all to the same values.
74+
update_values!!(vi, example_values, varnames)
75+
end
76+
end
77+
3478
"""
3579
logprior_true(model, args...)
3680

src/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
140140
unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)
141141

142142
# TODO: deprecate.
143-
unflatten(vi::VarInfo, spl, x::AbstractVector) = VarInfo(vi, spl, x)
143+
unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) = VarInfo(vi, spl, x)
144144

145145
# without AbstractSampler
146146
function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext)

0 commit comments

Comments
 (0)