@@ -46,45 +46,42 @@ DynamicPPL.childcontext(context::OptimizationContext) = context.context
46
46
DynamicPPL. setchildcontext (:: OptimizationContext , child) = OptimizationContext (child)
47
47
48
48
# assume
49
- function DynamicPPL. tilde_assume (rng:: Random.AbstractRNG , ctx:: OptimizationContext , spl, dist, vn, vi)
50
- return DynamicPPL. tilde_assume (ctx, spl, dist, vn, vi)
51
- end
52
-
53
- function DynamicPPL. tilde_assume (ctx:: OptimizationContext{<:LikelihoodContext} , spl, dist, vn, vi)
54
- r = vi[vn]
49
+ function DynamicPPL. tilde_assume (ctx:: OptimizationContext{<:LikelihoodContext} , dist, vn, vi)
50
+ r = vi[vn, dist]
55
51
return r, 0 , vi
56
52
end
57
53
58
- function DynamicPPL. tilde_assume (ctx:: OptimizationContext , spl, dist, vn, vi)
59
- r = vi[vn]
54
+ function DynamicPPL. tilde_assume (ctx:: OptimizationContext , dist, vn, vi)
55
+ r = vi[vn, dist ]
60
56
return r, Distributions. logpdf (dist, r), vi
61
57
end
62
58
63
59
# dot assume
64
- function DynamicPPL. dot_tilde_assume (rng:: Random.AbstractRNG , ctx:: OptimizationContext , sampler, right, left, vns, vi)
65
- return DynamicPPL. dot_tilde_assume (ctx, sampler, right, left, vns, vi)
66
- end
67
-
68
- function DynamicPPL. dot_tilde_assume (ctx:: OptimizationContext{<:LikelihoodContext} , sampler:: SampleFromPrior , right, left, vns, vi)
60
+ function DynamicPPL. dot_tilde_assume (ctx:: OptimizationContext{<:LikelihoodContext} , right, left, vns, vi)
69
61
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
70
62
# affect anything.
71
- r = DynamicPPL. get_and_set_val! (Random. GLOBAL_RNG, vi, vns, right, sampler)
63
+ # TODO : Stop using `get_and_set_val!`.
64
+ r = DynamicPPL. get_and_set_val! (Random. default_rng (), vi, vns, right, SampleFromPrior ())
72
65
return r, 0 , vi
73
66
end
74
67
75
- function DynamicPPL. dot_tilde_assume (ctx:: OptimizationContext , sampler:: SampleFromPrior , right, left, vns, vi)
68
+ _loglikelihood (dist:: Distribution , x) = loglikelihood (dist, x)
69
+ _loglikelihood (dists:: AbstractArray{<:Distribution} , x) = loglikelihood (arraydist (dists), x)
70
+
71
+ function DynamicPPL. dot_tilde_assume (ctx:: OptimizationContext , right, left, vns, vi)
76
72
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
77
73
# affect anything.
78
- r = DynamicPPL. get_and_set_val! (Random. GLOBAL_RNG, vi, vns, right, sampler)
79
- return r, loglikelihood (right, r), vi
74
+ # TODO : Stop using `get_and_set_val!`.
75
+ r = DynamicPPL. get_and_set_val! (Random. default_rng (), vi, vns, right, SampleFromPrior ())
76
+ return r, _loglikelihood (right, r), vi
80
77
end
81
78
82
79
"""
83
80
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
84
81
85
82
A struct that stores the negative log density function of a `DynamicPPL` model.
86
83
"""
87
- const OptimLogDensity{M<: Model ,C<: OptimizationContext ,V<: VarInfo } = Turing. LogDensityFunction{V,M,DynamicPPL . SampleFromPrior, C}
84
+ const OptimLogDensity{M<: Model ,C<: OptimizationContext ,V<: VarInfo } = Turing. LogDensityFunction{V,M,C}
88
85
89
86
"""
90
87
OptimLogDensity(model::Model, context::OptimizationContext)
@@ -93,21 +90,23 @@ Create a callable `OptimLogDensity` struct that evaluates a model using the give
93
90
"""
94
91
function OptimLogDensity (model:: Model , context:: OptimizationContext )
95
92
init = VarInfo (model)
96
- return Turing. LogDensityFunction (init, model, DynamicPPL . SampleFromPrior (), context)
93
+ return Turing. LogDensityFunction (init, model, context)
97
94
end
98
95
99
96
"""
100
- (f::OptimLogDensity)( z)
97
+ LogDensityProblems.logdensity (f::OptimLogDensity, z)
101
98
102
99
Evaluate the negative log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
103
100
at the array `z`.
104
101
"""
105
102
function (f:: OptimLogDensity )(z:: AbstractVector )
106
- sampler = f. sampler
107
- varinfo = DynamicPPL. unflatten (f. varinfo, sampler, z)
108
- return - getlogp (last (DynamicPPL. evaluate!! (f. model, varinfo, sampler, f. context)))
103
+ varinfo = DynamicPPL. unflatten (f. varinfo, z)
104
+ return - getlogp (last (DynamicPPL. evaluate!! (f. model, varinfo, f. context)))
109
105
end
110
106
107
+ # NOTE: This seems a bit weird IMO since this is the _negative_ log-likelihood.
108
+ LogDensityProblems. logdensity (f:: OptimLogDensity , z:: AbstractVector ) = f (z)
109
+
111
110
function (f:: OptimLogDensity )(F, G, z)
112
111
if G != = nothing
113
112
# Calculate negative log joint and its gradient.
@@ -127,7 +126,7 @@ function (f::OptimLogDensity)(F, G, z)
127
126
128
127
# Only negative log joint requested but no gradient.
129
128
if F != = nothing
130
- return f ( z)
129
+ return LogDensityProblems . logdensity (f, z)
131
130
end
132
131
133
132
return nothing
@@ -140,50 +139,44 @@ end
140
139
# ################################################
141
140
142
141
function transform!! (f:: OptimLogDensity )
143
- spl = f. sampler
144
-
145
142
# # Check link status of vi in OptimLogDensity
146
- linked = DynamicPPL. islinked (f. varinfo, spl )
143
+ linked = DynamicPPL. istrans (f. varinfo)
147
144
148
145
# # transform into constrained or unconstrained space depending on current state of vi
149
146
@set! f. varinfo = if ! linked
150
- DynamicPPL. link!! (f. varinfo, spl, f. model)
147
+ DynamicPPL. link!! (f. varinfo, f. model)
151
148
else
152
- DynamicPPL. invlink!! (f. varinfo, spl, f. model)
149
+ DynamicPPL. invlink!! (f. varinfo, f. model)
153
150
end
154
151
155
152
return f
156
153
end
157
154
158
155
function transform!! (p:: AbstractArray , vi:: DynamicPPL.VarInfo , model:: DynamicPPL.Model , :: constrained_space{true} )
159
- spl = DynamicPPL. SampleFromPrior ()
160
-
161
- linked = DynamicPPL. islinked (vi, spl)
156
+ linked = DynamicPPL. istrans (vi)
162
157
163
158
! linked && return identity (p) # TODO : why do we do `identity` here?
164
- vi = DynamicPPL. setindex!! (vi, p, spl )
165
- vi = DynamicPPL. invlink!! (vi, spl, model)
166
- p .= vi[spl ]
159
+ vi = DynamicPPL. unflatten (vi, p)
160
+ vi = DynamicPPL. invlink!! (vi, model)
161
+ p .= vi[: ]
167
162
168
163
# If linking mutated, we need to link once more.
169
- linked && DynamicPPL. link!! (vi, spl, model)
164
+ linked && DynamicPPL. link!! (vi, model)
170
165
171
166
return p
172
167
end
173
168
174
169
function transform!! (p:: AbstractArray , vi:: DynamicPPL.VarInfo , model:: DynamicPPL.Model , :: constrained_space{false} )
175
- spl = DynamicPPL. SampleFromPrior ()
176
-
177
- linked = DynamicPPL. islinked (vi, spl)
170
+ linked = DynamicPPL. istrans (vi)
178
171
if linked
179
- vi = DynamicPPL. invlink!! (vi, spl, model)
172
+ vi = DynamicPPL. invlink!! (vi, model)
180
173
end
181
- vi = DynamicPPL. setindex!! (vi, p, spl )
182
- vi = DynamicPPL. link!! (vi, spl, model)
183
- p .= vi[spl ]
174
+ vi = DynamicPPL. unflatten (vi, p)
175
+ vi = DynamicPPL. link!! (vi, model)
176
+ p .= vi[: ]
184
177
185
178
# If linking mutated, we need to link once more.
186
- ! linked && DynamicPPL. invlink!! (vi, spl, model)
179
+ ! linked && DynamicPPL. invlink!! (vi, model)
187
180
188
181
return p
189
182
end
@@ -208,26 +201,26 @@ end
208
201
209
202
function (t:: AbstractTransform )(p:: AbstractArray )
210
203
return transform (p, t. vi, t. model, t. space)
211
- end
204
+ end
212
205
213
206
function (t:: Init )()
214
207
return t. vi[DynamicPPL. SampleFromPrior ()]
215
208
end
216
209
217
210
function get_parameter_bounds (model:: DynamicPPL.Model )
218
211
vi = DynamicPPL. VarInfo (model)
219
- spl = DynamicPPL. SampleFromPrior ()
220
212
221
213
# # Check link status of vi
222
- linked = DynamicPPL. islinked (vi, spl)
214
+ linked = DynamicPPL. istrans (vi)
223
215
224
216
# # transform into unconstrained
225
217
if ! linked
226
- vi = DynamicPPL. link!! (vi, spl, model)
218
+ vi = DynamicPPL. link!! (vi, model)
227
219
end
228
-
229
- lb = transform (fill (- Inf ,length (vi[DynamicPPL. SampleFromPrior ()])), vi, model, constrained_space {true} ())
230
- ub = transform (fill (Inf ,length (vi[DynamicPPL. SampleFromPrior ()])), vi, model, constrained_space {true} ())
220
+
221
+ d = length (vi[:])
222
+ lb = transform (fill (- Inf , d), vi, model, constrained_space {true} ())
223
+ ub = transform (fill (Inf , d), vi, model, constrained_space {true} ())
231
224
232
225
return lb, ub
233
226
end
0 commit comments