Commit 3b1a6bc

Merge remote-tracking branch 'origin/main' into py/no-mooncake-pre

2 parents: 6974ab2 + 072234d

5 files changed: +28 -9 lines changed

HISTORY.md
Lines changed: 3 additions & 2 deletions

@@ -2,8 +2,9 @@
 
 ## 0.36.4
 
-Added a compatibility entry for JET.jl 0.10.
-This should only affect you if you are using DynamicPPL on the Julia 1.12 pre-release.
+Added compatibility with DifferentiationInterface.jl 0.7, and also with JET.jl 0.10.
+
+The JET compatibility entry should only affect you if you are using DynamicPPL on the Julia 1.12 pre-release.
 
 ## 0.36.3
 
Project.toml
Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ ChainRulesCore = "1"
 Chairmarks = "1.3.1"
 Compat = "4"
 ConstructionBase = "1.5.4"
-DifferentiationInterface = "0.6.41"
+DifferentiationInterface = "0.6.41, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
 EnzymeCore = "0.6 - 0.8"
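The new compat entry `"0.6.41, 0.7"` is a union of two caret-style specifiers: it admits any 0.6.x release at or above 0.6.41 as well as any 0.7.x release. A quick, illustrative way to sanity-check this is Pkg's semver-spec parser; the snippet below is not part of this commit, and `Pkg.Types.semver_spec` is an internal helper rather than public API.

```julia
# Illustrative check of what the compat entry "0.6.41, 0.7" covers.
# Note: Pkg.Types.semver_spec is an internal Pkg helper, not public API.
using Pkg

spec = Pkg.Types.semver_spec("0.6.41, 0.7")

v"0.6.40" in spec  # false - below the 0.6.41 lower bound
v"0.6.41" in spec  # true  - covered by "0.6.41", i.e. [0.6.41, 0.7.0)
v"0.7.3"  in spec  # true  - covered by "0.7", i.e. [0.7.0, 0.8.0)
v"0.8.0"  in spec  # false - outside both ranges
```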

benchmarks/src/DynamicPPLBenchmarks.jl
Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ Return the dimension of `model`, accounting for linking, if any.
 """
 function model_dimension(model, islinked)
     vi = VarInfo()
-    model(vi)
+    model(StableRNG(23), vi)
     if islinked
         vi = DynamicPPL.link(vi, model)
     end
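Passing an explicit `StableRNG(23)` when evaluating the model makes the dummy draws used by `model_dimension` reproducible across benchmark runs. A minimal sketch of the pattern follows, using a hypothetical toy model (`demo` is illustrative, not part of this commit).

```julia
using DynamicPPL, Distributions, StableRNGs

# Hypothetical toy model, for illustration only.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()

# Evaluating with a fixed-seed RNG fills the VarInfo with reproducible draws.
vi1 = VarInfo()
model(StableRNG(23), vi1)

vi2 = VarInfo()
model(StableRNG(23), vi2)

vi1[:] == vi2[:]  # expected: true, since both evaluations used the same seed
```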

src/logdensityfunction.jl
Lines changed: 22 additions & 4 deletions

@@ -124,9 +124,7 @@ struct LogDensityFunction{
         # Get a set of dummy params to use for prep
         x = map(identity, varinfo[:])
         if use_closure(adtype)
-            prep = DI.prepare_gradient(
-                x -> logdensity_at(x, model, varinfo, context), adtype, x
-            )
+            prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)
         else
             prep = DI.prepare_gradient(
                 logdensity_at,
@@ -184,6 +182,26 @@ function logdensity_at(
     return getlogp(last(evaluate!!(model, varinfo_new, context)))
 end
 
+"""
+    LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}(
+        model::M
+        varinfo::V
+        context::C
+    )
+
+A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
+varinfo, context)`.
+"""
+struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
+    model::M
+    varinfo::V
+    context::C
+end
+function (ld::LogDensityAt)(x::AbstractVector)
+    varinfo_new = unflatten(ld.varinfo, x)
+    return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
+end
+
 ### LogDensityProblems interface
 
 function LogDensityProblems.capabilities(
@@ -209,7 +227,7 @@ function LogDensityProblems.logdensity_and_gradient(
     # branches happen to return different types)
     return if use_closure(f.adtype)
         DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
+            LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
        )
    else
        DI.value_and_gradient(
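This commit replaces the anonymous closure `x -> logdensity_at(x, model, varinfo, context)` with the named callable struct `LogDensityAt` in both `DI.prepare_gradient` and `DI.value_and_gradient`. One likely benefit is that a named callable has a single concrete type, so the callable passed at gradient time matches the type of the callable used for preparation, whereas two separately written anonymous closures have distinct types. Below is a standalone sketch of the callable-struct pattern with DifferentiationInterface, using made-up names (`SumOfSquares` is illustrative and not part of DynamicPPL).

```julia
# Standalone sketch of the callable-struct pattern; names are illustrative.
using DifferentiationInterface
import ForwardDiff  # any AD backend supported by DifferentiationInterface

# A callable struct plays the role of a closure: it captures `weights`
# in a field and can be called like a function.
struct SumOfSquares
    weights::Vector{Float64}
end
(f::SumOfSquares)(x::AbstractVector) = sum(f.weights .* x .^ 2)

f = SumOfSquares([1.0, 2.0, 3.0])
x = [0.5, -1.0, 2.0]
backend = AutoForwardDiff()

# Preparation is built once; because SumOfSquares is a named concrete type,
# an identically-typed callable can be reconstructed wherever the prepared
# gradient is later evaluated.
prep = prepare_gradient(f, backend, x)
value, grad = value_and_gradient(f, prep, backend, x)
```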

test/Project.toml
Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ Aqua = "0.8"
 Bijectors = "0.15.1"
 Combinatorics = "1"
 Compat = "4.3.0"
-DifferentiationInterface = "0.6.41"
+DifferentiationInterface = "0.6.41, 0.7"
 Distributions = "0.25"
 DistributionsAD = "0.6.3"
 EnzymeCore = "0.6 - 0.8"
