Skip to content

Commit b9f2cc0

Browse files
authored
Remove getmodel and setmodel after transferred to DynamicPPL (#2292)
* remove `getmodel` and `setmodel` after transferred to DynamicPPL * bump version
1 parent 29a1342 commit b9f2cc0

File tree

2 files changed

+4
-41
lines changed

2 files changed

+4
-41
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.33.2"
3+
version = "0.33.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
6363
DistributionsAD = "0.6"
6464
DocStringExtensions = "0.8, 0.9"
6565
DynamicHMC = "3.4"
66-
DynamicPPL = "0.28.1"
66+
DynamicPPL = "0.28.2"
6767
Compat = "4.15.0"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3"

src/mcmc/abstractmcmc.jl

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,50 +17,13 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
1717
return transition_to_turing(parent(f), transition)
1818
end
1919

20-
"""
21-
getmodel(f)
22-
23-
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
24-
"""
25-
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
26-
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
27-
28-
"""
29-
setmodel(f, model[, adtype])
30-
31-
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
32-
33-
!!! warning
34-
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
35-
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
36-
might require recompilation of the gradient tape, depending on the AD backend.
37-
"""
38-
function setmodel(
39-
f::LogDensityProblemsAD.ADGradientWrapper,
40-
model::DynamicPPL.Model,
41-
adtype::ADTypes.AbstractADType
42-
)
43-
# TODO: Should we handle `SciMLBase.NoAD`?
44-
# For an `ADGradientWrapper` we do the following:
45-
# 1. Update the `Model` in the underlying `LogDensityFunction`.
46-
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
47-
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
48-
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
49-
# replacing the corresponding field with the new model won't be sufficient to obtain
50-
# the correct gradients.
51-
return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model))
52-
end
53-
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
54-
return Accessors.@set f.model = model
55-
end
56-
5720
function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
5821
return varinfo_from_logdensityfn(parent(f))
5922
end
6023
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo
6124

6225
function varinfo(state::TuringState)
63-
θ = getparams(getmodel(state.logdensity), state.state)
26+
θ = getparams(DynamicPPL.getmodel(state.logdensity), state.state)
6427
# TODO: Do we need to link here first?
6528
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
6629
end
@@ -97,7 +60,7 @@ function recompute_logprob!!(
9760
)
9861
# Re-using the log-density function from the `state` and updating only the `model` field,
9962
# since the `model` might now contain different conditioning values.
100-
f = setmodel(state.logdensity, model, sampler.alg.adtype)
63+
f = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype)
10164
# Recompute the log-probability with the new `model`.
10265
state_inner = recompute_logprob!!(
10366
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state

0 commit comments

Comments
 (0)