diff --git a/HISTORY.md b/HISTORY.md
index 4e9bc2d42..26eaa2d39 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,9 @@
 # DynamicPPL Changelog
 
+## 0.36.4
+
+Added compatibility with DifferentiationInterface.jl 0.7.
+
 ## 0.36.3
 
 Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo.
diff --git a/Project.toml b/Project.toml
index 5bef5bcb1..128822367 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.36.3"
+version = "0.36.4"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -54,7 +54,7 @@ ChainRulesCore = "1"
 Chairmarks = "1.3.1"
 Compat = "4"
 ConstructionBase = "1.5.4"
-DifferentiationInterface = "0.6.41"
+DifferentiationInterface = "0.6.41, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
 EnzymeCore = "0.6 - 0.8"
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index a42855f05..443c435e0 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -124,9 +124,7 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
             if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
-                )
+                prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x)
             else
                 prep = DI.prepare_gradient(
                     logdensity_at,
@@ -184,6 +182,26 @@ function logdensity_at(
     return getlogp(last(evaluate!!(model, varinfo_new, context)))
 end
 
+"""
+    LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}(
+        model::M,
+        varinfo::V,
+        context::C,
+    )
+
+A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
+varinfo, context)`.
+"""
+struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
+    model::M
+    varinfo::V
+    context::C
+end
+function (ld::LogDensityAt)(x::AbstractVector)
+    varinfo_new = unflatten(ld.varinfo, x)
+    return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context)))
+end
+
 ### LogDensityProblems interface
 
 function LogDensityProblems.capabilities(
@@ -209,7 +227,7 @@ function LogDensityProblems.logdensity_and_gradient(
     # branches happen to return different types)
     return if use_closure(f.adtype)
         DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
+            LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x
        )
     else
         DI.value_and_gradient(
diff --git a/test/Project.toml b/test/Project.toml
index 79e6d129b..92e81bb83 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -38,7 +38,7 @@ Aqua = "0.8"
 Bijectors = "0.15.1"
 Combinatorics = "1"
 Compat = "4.3.0"
-DifferentiationInterface = "0.6.41"
+DifferentiationInterface = "0.6.41, 0.7"
 Distributions = "0.25"
 DistributionsAD = "0.6.3"
 Documenter = "1"