Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.4

Added compatibility with DifferentiationInterface.jl 0.7.

## 0.36.3

Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo.
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.36.3"
version = "0.36.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -54,7 +54,7 @@ ChainRulesCore = "1"
Chairmarks = "1.3.1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41"
DifferentiationInterface = "0.6.41, 0.7"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
Expand Down
27 changes: 23 additions & 4 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
if use_closure(adtype)
prep = DI.prepare_gradient(
x -> logdensity_at(x, model, varinfo, context), adtype, x
)
prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)

Check warning on line 127 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L127

Added line #L127 was not covered by tests
else
prep = DI.prepare_gradient(
logdensity_at,
Expand Down Expand Up @@ -184,6 +182,27 @@
return getlogp(last(evaluate!!(model, varinfo_new, context)))
end

"""
    LogDensityAt(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)

A callable struct that serves the same purpose as the closure
`x -> logdensity_at(x, model, varinfo, context)`.

Calling an instance with a parameter vector `x::AbstractVector` inserts `x`
into `varinfo` via `unflatten`, re-evaluates `model` under `context`, and
returns the resulting log density.

The fields are parametrized (rather than typed with the abstract types
`Model`/`AbstractVarInfo`/`AbstractContext` directly) so that each concrete
combination gets its own specialized callable type, avoiding abstract-field
boxing when the instance is used as the function argument to AD preparation.
"""
struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
    model::M
    varinfo::V
    context::C
end
function (ld::LogDensityAt)(x::AbstractVector)
    # Substitute the new parameter values, then re-evaluate the model to
    # obtain the log density at `x`.
    varinfo_new = unflatten(ld.varinfo, x)
    return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
end

### LogDensityProblems interface

function LogDensityProblems.capabilities(
Expand All @@ -209,7 +228,7 @@
# branches happen to return different types)
return if use_closure(f.adtype)
DI.value_and_gradient(
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
)
else
DI.value_and_gradient(
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Aqua = "0.8"
Bijectors = "0.15.1"
Combinatorics = "1"
Compat = "4.3.0"
DifferentiationInterface = "0.6.41"
DifferentiationInterface = "0.6.41, 0.7"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "1"
Expand Down
Loading