# Unwrap the AD gradient wrapper and forward to the method for the underlying
# `DynamicPPL.LogDensityFunction`.
transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition) =
    transition_to_turing(parent(f), transition)
20
- """
21
- getmodel(f)
22
-
23
- Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
24
- """
25
- getmodel (f:: LogDensityProblemsAD.ADGradientWrapper ) = getmodel (parent (f))
26
- getmodel (f:: DynamicPPL.LogDensityFunction ) = f. model
27
-
28
- """
29
- setmodel(f, model[, adtype])
30
-
31
- Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
32
-
33
- !!! warning
34
- Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
35
- `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
36
- might require recompilation of the gradient tape, depending on the AD backend.
37
- """
38
- function setmodel (
39
- f:: LogDensityProblemsAD.ADGradientWrapper ,
40
- model:: DynamicPPL.Model ,
41
- adtype:: ADTypes.AbstractADType
42
- )
43
- # TODO : Should we handle `SciMLBase.NoAD`?
44
- # For an `ADGradientWrapper` we do the following:
45
- # 1. Update the `Model` in the underlying `LogDensityFunction`.
46
- # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
47
- # to ensure that the recompilation of gradient tapes, etc. also occur. For example,
48
- # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
49
- # replacing the corresponding field with the new model won't be sufficient to obtain
50
- # the correct gradients.
51
- return LogDensityProblemsAD. ADgradient (adtype, setmodel (parent (f), model))
52
- end
53
- function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
54
- return Accessors. @set f. model = model
55
- end
56
-
57
20
# Unwrap the AD wrapper and recurse into the wrapped log-density function.
varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper) =
    varinfo_from_logdensityfn(parent(f))
function varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction)
    return f.varinfo
end
61
24
62
25
# Reconstruct a `VarInfo` from the parameters stored in `state`.
function varinfo(state::TuringState)
    model = DynamicPPL.getmodel(state.logdensity)
    θ = getparams(model, state.state)
    # TODO: Do we need to link here first?
    return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
end
@@ -97,7 +60,7 @@ function recompute_logprob!!(
97
60
)
98
61
# Re-using the log-density function from the `state` and updating only the `model` field,
99
62
# since the `model` might now contain different conditioning values.
100
- f = setmodel (state. logdensity, model, sampler. alg. adtype)
63
+ f = DynamicPPL . setmodel (state. logdensity, model, sampler. alg. adtype)
101
64
# Recompute the log-probability with the new `model`.
102
65
state_inner = recompute_logprob!! (
103
66
rng, AbstractMCMC. LogDensityModel (f), sampler. alg. sampler, state. state
0 commit comments