Skip to content

Commit 80a54f1

Browse files
committed
Merge branch 'master' into step
2 parents 444cfd5 + 26d90d7 commit 80a54f1

File tree

7 files changed

+116
-88
lines changed

7 files changed

+116
-88
lines changed

src/compiler.jl

Lines changed: 15 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ function generate_mainbody!(found, expr::Expr, args)
208208
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
209209
end
210210

211-
# """ Unbreak code highlighting in Emacs julia-mode
212211

213212

214213
"""
@@ -218,7 +217,7 @@ Generate an `observe` expression for data variables and `assume` expression for
218217
variables.
219218
"""
220219
function generate_tilde(left, right, args)
221-
@gensym tmpright tmpleft
220+
@gensym tmpright
222221
top = [:($tmpright = $right),
223222
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
224223
|| throw(ArgumentError($DISTMSG)))]
@@ -227,48 +226,32 @@ function generate_tilde(left, right, args)
227226
@gensym out vn inds
228227
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
229228

230-
assumption = [
231-
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
232-
_varinfo)),
233-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
234-
:($left = $out[1])
235-
]
236-
237229
# It can only be an observation if the LHS is an argument of the model
238230
if vsym(left) in args
239231
@gensym isassumption
240232
return quote
241233
$(top...)
242234
$isassumption = $(DynamicPPL.isassumption(left))
243235
if $isassumption
244-
$(assumption...)
236+
$left = $(DynamicPPL.tilde_assume)(
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
245238
else
246-
$tmpleft = $left
247-
$(DynamicPPL.acclogp!)(
248-
_varinfo,
249-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
250-
$vn, $inds, _varinfo)
251-
)
252-
$tmpleft
239+
$(DynamicPPL.tilde_observe)(
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
253241
end
254242
end
255243
end
256244

257245
return quote
258246
$(top...)
259-
$(assumption...)
247+
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
260248
end
261249
end
262250

263251
# If the LHS is a literal, it is always an observation
264252
return quote
265253
$(top...)
266-
$tmpleft = $left
267-
$(DynamicPPL.acclogp!)(
268-
_varinfo,
269-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
270-
)
271-
$tmpleft
254+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
272255
end
273256
end
274257

@@ -278,7 +261,7 @@ end
278261
Generate the expression that replaces `left .~ right` in the model body.
279262
"""
280263
function generate_dot_tilde(left, right, args)
281-
@gensym tmpright tmpleft
264+
@gensym tmpright
282265
top = [:($tmpright = $right),
283266
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
284267
|| throw(ArgumentError($DISTMSG)))]
@@ -287,49 +270,33 @@ function generate_dot_tilde(left, right, args)
287270
@gensym out vn inds
288271
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
289272

290-
assumption = [
291-
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
292-
$vn, $inds, _varinfo)),
293-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
294-
:($left .= $out[1])
295-
]
296-
297273
# It can only be an observation if the LHS is an argument of the model
298274
if vsym(left) in args
299275
@gensym isassumption
300276
return quote
301277
$(top...)
302278
$isassumption = $(DynamicPPL.isassumption(left))
303279
if $isassumption
304-
$(assumption...)
280+
$left .= $(DynamicPPL.dot_tilde_assume)(
281+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
305282
else
306-
$tmpleft = $left
307-
$(DynamicPPL.acclogp!)(
308-
_varinfo,
309-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
310-
$tmpleft, $vn, $inds, _varinfo)
311-
)
312-
$tmpleft
283+
$(DynamicPPL.dot_tilde_observe)(
284+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
313285
end
314286
end
315287
end
316288

317289
return quote
318290
$(top...)
319-
$(assumption...)
291+
$left .= $(DynamicPPL.dot_tilde_assume)(
292+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
320293
end
321294
end
322295

323296
# If the LHS is a literal, it is always an observation
324297
return quote
325298
$(top...)
326-
$tmpleft = $left
327-
$(DynamicPPL.acclogp!)(
328-
_varinfo,
329-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
330-
_varinfo)
331-
)
332-
$tmpleft
299+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
333300
end
334301
end
335302

src/context_implementations.jl

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ end
3838
"""
3939
tilde_assume(ctx, sampler, right, vn, inds, vi)
4040
41-
This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where
42-
`x` does not occur in the model inputs.
41+
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42+
accumulate the log probability, and return the sampled value.
4343
4444
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4545
"""
4646
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47-
return tilde(ctx, sampler, right, vn, inds, vi)
47+
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
48+
acclogp!(vi, logp)
49+
return value
4850
end
4951

5052

@@ -72,24 +74,30 @@ end
7274
"""
7375
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7476
75-
This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where
76-
`x` does occur in the model inputs.
77+
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
78+
accumulate the log probability, and return the observed value.
7779
78-
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
79-
name and indices; if needed, these can be accessed through this function, though.
80+
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
81+
and indices; if needed, these can be accessed through this function, though.
8082
"""
8183
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
82-
return tilde(ctx, sampler, right, left, vi)
84+
logp = tilde(ctx, sampler, right, left, vi)
85+
acclogp!(vi, logp)
86+
return left
8387
end
8488

8589
"""
8690
tilde_observe(ctx, sampler, right, left, vi)
8791
88-
This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`.
92+
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93+
observed value.
94+
8995
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9096
"""
9197
function tilde_observe(ctx, sampler, right, left, vi)
92-
return tilde(ctx, sampler, right, left, vi)
98+
logp = tilde(ctx, sampler, right, left, vi)
99+
acclogp!(vi, logp)
100+
return left
93101
end
94102

95103

@@ -103,24 +111,48 @@ function observe(spl::Sampler, weight)
103111
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
104112
end
105113

114+
# If parameters exist, they are used and not overwritten.
106115
function assume(
107-
spl::Union{SampleFromPrior, SampleFromUniform},
116+
spl::SampleFromPrior,
108117
dist::Distribution,
109118
vn::VarName,
110119
vi::VarInfo,
111120
)
112121
if haskey(vi, vn)
113122
if is_flagged(vi, vn, "del")
114123
unset_flag!(vi, vn, "del")
115-
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
124+
r = rand(dist)
116125
vi[vn] = vectorize(dist, r)
126+
settrans!(vi, false, vn)
117127
setorder!(vi, vn, get_num_produce(vi))
118128
else
119-
r = vi[vn]
129+
r = vi[vn]
120130
end
121131
else
122-
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
132+
r = rand(dist)
123133
push!(vi, vn, r, dist, spl)
134+
settrans!(vi, false, vn)
135+
end
136+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
137+
end
138+
139+
# Always overwrites the parameters with new ones.
140+
function assume(
141+
spl::SampleFromUniform,
142+
dist::Distribution,
143+
vn::VarName,
144+
vi::VarInfo,
145+
)
146+
if haskey(vi, vn)
147+
unset_flag!(vi, vn, "del")
148+
r = init(dist)
149+
vi[vn] = vectorize(dist, r)
150+
settrans!(vi, true, vn)
151+
setorder!(vi, vn, get_num_produce(vi))
152+
else
153+
r = init(dist)
154+
push!(vi, vn, r, dist, spl)
155+
settrans!(vi, true, vn)
124156
end
125157
# NOTE: The importance weight is not correctly computed here because
126158
# r is genereated from some uniform distribution which is different from the prior
@@ -191,13 +223,15 @@ end
191223
"""
192224
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
193225
194-
This method is applied in the generated code for assumed vectorized variables, e.g., `x .~
195-
MvNormal()` where `x` does not occur in the model inputs.
226+
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
227+
model inputs), accumulate the log probability, and return the sampled value.
196228
197229
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
198230
"""
199231
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200-
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
232+
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
233+
acclogp!(vi, logp)
234+
return value
201235
end
202236

203237

@@ -367,24 +401,30 @@ end
367401
"""
368402
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
369403
370-
This method is applied in the generated code for vectorized observed variables, e.g., `x .~
371-
MvNormal()` where `x` does occur the model inputs.
404+
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
405+
accumulate the log probability, and return the observed value.
372406
373407
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
374408
name and indices; if needed, these can be accessed through this function, though.
375409
"""
376410
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
377-
return dot_tilde(ctx, sampler, right, left, vi)
411+
logp = dot_tilde(ctx, sampler, right, left, vi)
412+
acclogp!(vi, logp)
413+
return left
378414
end
379415

380416
"""
381417
dot_tilde_observe(ctx, sampler, right, left, vi)
382418
383-
This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~
384-
MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
419+
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
420+
probability, and return the observed value.
421+
422+
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
385423
"""
386424
function dot_tilde_observe(ctx, sampler, right, left, vi)
387-
return dot_tilde(ctx, sampler, right, left, vi)
425+
logp = dot_tilde(ctx, sampler, right, left, vi)
426+
acclogp!(vi, logp)
427+
return left
388428
end
389429

390430

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ end
117117

118118
# ROBUST INITIALISATIONS
119119
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
120-
randrealuni() = Real(2rand())
121-
randrealuni(args...) = map(Real, 2rand(args...))
120+
randrealuni() = 4 * rand() - 2
121+
randrealuni(args...) = 4 .* rand(args...) .- 2
122122

123123
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
124124

src/varinfo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ end
420420

421421
# Get all vns of variables belonging to spl
422422
_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl)))
423+
_getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = _getvns(vi, Selector(), Val(()))
423424
_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space))
424425
function _getvns(vi::TypedVarInfo, s::Selector, space)
425426
return _getvns(vi.metadata, _getidcs(vi, s, space))
@@ -820,6 +821,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`.
820821
The value(s) may or may not be transformed to Euclidean space.
821822
"""
822823
getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi))
824+
getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi))
823825
getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))
824826
function getindex(vi::TypedVarInfo, spl::Sampler)
825827
# Gets the ranges as a NamedTuple

src/varname.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ _issubrange(i::Colon, j::ConcreteIndex) = true
139139
140140
A macro that returns an instance of `VarName` given the symbol or expression of a Julia variable,
141141
e.g. `@varname x[1,2][1+5][45][3]` returns `VarName{:x}(((1, 2), (6,), (45,), (3,)))`.
142+
143+
!!! compat "Julia 1.5"
144+
Using `begin` in an indexing expression to refer to the first index requires at least
145+
Julia 1.5.
142146
"""
143147
macro varname(expr::Union{Expr, Symbol})
144148
return esc(varname(expr))
@@ -177,8 +181,12 @@ end
177181
"""
178182
@vinds(expr)
179183
180-
Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1, :][2]` returns
184+
Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1, :][2]` returns
181185
`((1, Colon()), (2,))`.
186+
187+
!!! compat "Julia 1.5"
188+
Using `begin` in an indexing expression to refer to the first index requires at least
189+
Julia 1.5.
182190
"""
183191
macro vinds(expr::Union{Expr, Symbol})
184192
return esc(vinds(expr))
@@ -188,7 +196,11 @@ vinds(expr::Symbol) = Expr(:tuple)
188196
function vinds(expr::Expr)
189197
if Meta.isexpr(expr, :ref)
190198
ex = copy(expr)
191-
Base.replace_ref_end!(ex)
199+
@static if VERSION < v"1.5.0-DEV.666"
200+
Base.replace_ref_end!(ex)
201+
else
202+
Base.replace_ref_begin_end!(ex)
203+
end
192204
last = Expr(:tuple, ex.args[2:end]...)
193205
init = vinds(ex.args[1]).args
194206
return Expr(:tuple, init..., last)
@@ -197,7 +209,6 @@ function vinds(expr::Expr)
197209
end
198210
end
199211

200-
201212
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
202213
return s in argnames
203214
end

test/compiler.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,14 @@ end
543543
var_expr = :(x[2:3,2:3][[1,2],[1,2]])
544544
@test vsym(var_expr) == :x
545545
@test vinds(var_expr) == :(((2:3, 2:3), ([1, 2], [1, 2])))
546+
547+
var_expr = :(x[end])
548+
@test vsym(var_expr) == :x
549+
@test vinds(var_expr) == :((($lastindex(x),),))
550+
551+
var_expr = :(x[1, end])
552+
@test vsym(var_expr) == :x
553+
@test vinds(var_expr) == :(((1, $lastindex(x, 2)),))
546554
end
547555
@testset "user-defined variable name" begin
548556
@model f1() = begin

0 commit comments

Comments
 (0)