Skip to content

Commit 408ddb1

Browse files
authored
use ::Any in LogDensityFunction (#1189)
1 parent c5a646f commit 408ddb1

File tree

4 files changed

+34
-9
lines changed

4 files changed

+34
-9
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.8
4+
5+
Allow the `getlogdensity` argument of `LogDensityFunction` to accept callable structs as well as functions.
6+
37
## 0.39.7
48

59
Improve concreteness when merging two `Metadata` structs.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.7"
3+
version = "0.39.8"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/logdensityfunction.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Random: Random
3232
"""
3333
DynamicPPL.LogDensityFunction(
3434
model::Model,
35-
getlogdensity::Function=getlogjoint_internal,
35+
getlogdensity::Any=getlogjoint_internal,
3636
varinfo::AbstractVarInfo=VarInfo(model);
3737
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
3838
)
@@ -47,7 +47,9 @@ using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gra
4747
`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD
4848
backend type, then `logdensity_and_gradient` is also implemented.
4949
50-
There are several options for `getlogdensity` that are 'supported' out of the box:
50+
`getlogdensity` should be a callable which takes a single argument: a `VarInfo`, and returns
51+
a `Real` corresponding to the log density of interest. There are several functions in
52+
DynamicPPL that are 'supported' out of the box:
5153
5254
- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term
5355
for any variables that have been linked in the provided VarInfo.
@@ -145,7 +147,7 @@ struct LogDensityFunction{
145147
Tlink,
146148
M<:Model,
147149
AD<:Union{ADTypes.AbstractADType,Nothing},
148-
F<:Function,
150+
F,
149151
N<:NamedTuple,
150152
ADP<:Union{Nothing,DI.GradientPrep},
151153
# type of the vector passed to logdensity functions
@@ -161,7 +163,7 @@ struct LogDensityFunction{
161163

162164
function LogDensityFunction(
163165
model::Model,
164-
getlogdensity::Function=getlogjoint_internal,
166+
getlogdensity::Any=getlogjoint_internal,
165167
varinfo::AbstractVarInfo=VarInfo(model);
166168
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
167169
)
@@ -219,12 +221,12 @@ end
219221
# LogDensityProblems.jl interface #
220222
###################################
221223
"""
222-
ldf_accs(getlogdensity::Function)
224+
ldf_accs(getlogdensity::Any)
223225
224226
Determine which accumulators are needed for fast evaluation with the given
225-
`getlogdensity` function.
227+
`getlogdensity` callable.
226228
"""
227-
ldf_accs(::Function) = default_accumulators()
229+
ldf_accs(::Any) = default_accumulators()
228230
ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
229231
function ldf_accs(::typeof(getlogjoint))
230232
return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
@@ -235,7 +237,7 @@ end
235237
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
236238
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
237239

238-
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
240+
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
239241
model::M
240242
getlogdensity::F
241243
iden_varname_ranges::N

test/logdensityfunction.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,25 @@ end
104104
@test LogDensityProblems.capabilities(typeof(ldf)) ==
105105
LogDensityProblems.LogDensityOrder{1}()
106106
end
107+
108+
@testset "Callable struct as getlogdensity" begin
109+
@model function f()
110+
x ~ Normal()
111+
return 1.0 ~ Normal(x)
112+
end
113+
struct ScaledLogLike
114+
scale::Float64
115+
end
116+
function (sll::ScaledLogLike)(vi::AbstractVarInfo)
117+
return sll.scale * getloglikelihood(vi)
118+
end
119+
model = f()
120+
vi = VarInfo(model)
121+
sll = ScaledLogLike(2.0)
122+
ldf = DynamicPPL.LogDensityFunction(model, sll, vi)
123+
x = vi[:]
124+
@test LogDensityProblems.logdensity(ldf, x) == sll.scale * logpdf(Normal(x[1]), 1.0)
125+
end
107126
end
108127

109128
@testset "LogDensityFunction: Type stability" begin

0 commit comments

Comments
 (0)