Skip to content

Commit c0261f4

Browse files
committed
Add fallback functions for model evaluation from Turing code
1 parent 16fceea commit c0261f4

File tree

3 files changed

+384
-3
lines changed

3 files changed

+384
-3
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ export VarName,
8282

8383
# Used here and overloaded in Turing
8484
function getspace end
85-
function tilde end
86-
function dot_tilde end
8785

8886
include("utils.jl")
8987
include("selector.jl")
@@ -93,6 +91,7 @@ include("varname.jl")
9391
include("distribution_wrappers.jl")
9492
include("contexts.jl")
9593
include("varinfo.jl")
94+
include("context_implementations.jl")
9695
include("compiler.jl")
9796
include("prob_macro.jl")
9897

src/context_implementations.jl

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
2+
3+
# utility funcs for querying sampler information
4+
require_gradient(spl::Sampler) = false
5+
require_particles(spl::Sampler) = false
6+
7+
_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
8+
_getindex(x, inds::Tuple{}) = x
9+
10+
# assume
11+
function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
12+
return _tilde(sampler, right, vn, vi)
13+
end
14+
function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
15+
if ctx.vars !== nothing
16+
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
17+
settrans!(vi, false, vn)
18+
end
19+
return _tilde(sampler, right, vn, vi)
20+
end
21+
function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
22+
if ctx.vars !== nothing
23+
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
24+
settrans!(vi, false, vn)
25+
end
26+
return _tilde(sampler, NoDist(right), vn, vi)
27+
end
28+
function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
29+
return tilde(ctx.ctx, sampler, right, left, inds, vi)
30+
end
31+
32+
function _tilde(sampler, right, vn::VarName, vi)
33+
return Turing.assume(sampler, right, vn, vi)
34+
end
35+
function _tilde(sampler, right::NamedDist, vn::VarName, vi)
36+
name = right.name
37+
if name isa String
38+
sym_str, inds = split_var_str(name, String)
39+
sym = Symbol(sym_str)
40+
vn = VarName{sym}(inds)
41+
elseif name isa Symbol
42+
vn = VarName{name}("")
43+
elseif name isa VarName
44+
vn = name
45+
else
46+
throw("Unsupported variable name. Please use either a string, symbol or VarName.")
47+
end
48+
return _tilde(sampler, right.dist, vn, vi)
49+
end
50+
51+
# observe
52+
function tilde(ctx::DefaultContext, sampler, right, left, vi)
53+
return _tilde(sampler, right, left, vi)
54+
end
55+
function tilde(ctx::PriorContext, sampler, right, left, vi)
56+
return 0
57+
end
58+
function tilde(ctx::LikelihoodContext, sampler, right, left, vi)
59+
return _tilde(sampler, right, left, vi)
60+
end
61+
function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
62+
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
63+
end
64+
65+
_tilde(sampler, right, left, vi) = Turing.observe(sampler, right, left, vi)
66+
67+
function assume(spl::Sampler, dist)
68+
error("Turing.assume: unmanaged inference algorithm: $(typeof(spl))")
69+
end
70+
71+
function observe(spl::Sampler, weight)
72+
error("Turing.observe: unmanaged inference algorithm: $(typeof(spl))")
73+
end
74+
75+
function assume(
76+
spl::Union{SampleFromPrior, SampleFromUniform},
77+
dist::Distribution,
78+
vn::VarName,
79+
vi::VarInfo,
80+
)
81+
if haskey(vi, vn)
82+
if is_flagged(vi, vn, "del")
83+
unset_flag!(vi, vn, "del")
84+
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
85+
vi[vn] = vectorize(dist, r)
86+
setorder!(vi, vn, get_num_produce(vi))
87+
else
88+
r = vi[vn]
89+
end
90+
else
91+
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
92+
push!(vi, vn, r, dist, spl)
93+
end
94+
# NOTE: The importance weight is not correctly computed here because
95+
# r is genereated from some uniform distribution which is different from the prior
96+
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
97+
98+
return r, logpdf_with_trans(dist, r, istrans(vi, vn))
99+
end
100+
101+
function observe(
102+
spl::Union{SampleFromPrior, SampleFromUniform},
103+
dist::Distribution,
104+
value,
105+
vi::VarInfo,
106+
)
107+
increment_num_produce!(vi)
108+
return logpdf(dist, value)
109+
end
110+
111+
# .~ functions
112+
113+
# assume
114+
function dot_tilde(ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi)
115+
vns, dist = get_vns_and_dist(right, left, vn)
116+
return _dot_tilde(sampler, dist, left, vns, vi)
117+
end
118+
function dot_tilde(
119+
ctx::LikelihoodContext,
120+
sampler,
121+
right,
122+
left,
123+
vn::VarName,
124+
inds,
125+
vi,
126+
)
127+
if ctx.vars !== nothing
128+
var = _getindex(getfield(ctx.vars, getsym(vn)), inds)
129+
vns, dist = get_vns_and_dist(right, var, vn)
130+
set_val!(vi, vns, dist, var)
131+
settrans!.(Ref(vi), false, vns)
132+
else
133+
vns, dist = get_vns_and_dist(right, left, vn)
134+
end
135+
return _dot_tilde(sampler, NoDist(dist), left, vns, vi)
136+
end
137+
function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi)
138+
return dot_tilde(ctx.ctx, sampler, right, left, vn, inds, vi)
139+
end
140+
function dot_tilde(
141+
ctx::PriorContext,
142+
sampler,
143+
right,
144+
left,
145+
vn::VarName,
146+
inds,
147+
vi,
148+
)
149+
if ctx.vars !== nothing
150+
var = _getindex(getfield(ctx.vars, getsym(vn)), inds)
151+
vns, dist = get_vns_and_dist(right, var, vn)
152+
set_val!(vi, vns, dist, var)
153+
settrans!.(Ref(vi), false, vns)
154+
else
155+
vns, dist = get_vns_and_dist(right, left, vn)
156+
end
157+
return _dot_tilde(sampler, dist, left, vns, vi)
158+
end
159+
160+
function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
161+
name = dist.name
162+
if name isa String
163+
sym_str, inds = split_var_str(name, String)
164+
sym = Symbol(sym_str)
165+
vn = VarName{sym}(inds)
166+
elseif name isa Symbol
167+
vn = VarName{name}("")
168+
elseif name isa VarName
169+
vn = name
170+
else
171+
throw("Unsupported variable name. Please use either a string, symbol or VarName.")
172+
end
173+
return get_vns_and_dist(dist.dist, var, vn)
174+
end
175+
function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName)
176+
getvn = i -> VarName(vn, vn.indexing * "[Colon(),$i]")
177+
return getvn.(1:size(var, 2)), dist
178+
end
179+
function get_vns_and_dist(
180+
dist::Union{Distribution, AbstractArray{<:Distribution}},
181+
var::AbstractArray,
182+
vn::VarName
183+
)
184+
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
185+
return getvn.(CartesianIndices(var)), dist
186+
end
187+
188+
function _dot_tilde(sampler, right, left, vns::AbstractArray{<:VarName}, vi)
189+
return dot_assume(sampler, right, vns, left, vi)
190+
end
191+
192+
# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
193+
function _dot_tilde(
194+
sampler::AbstractSampler,
195+
right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}},
196+
left::AbstractMatrix{>:AbstractVector},
197+
vn::AbstractVector{<:VarName},
198+
vi::VarInfo,
199+
)
200+
throw(ambiguity_error_msg())
201+
end
202+
203+
function dot_assume(
204+
spl::Union{SampleFromPrior, SampleFromUniform},
205+
dist::MultivariateDistribution,
206+
vns::AbstractVector{<:VarName},
207+
var::AbstractMatrix,
208+
vi::VarInfo,
209+
)
210+
@assert length(dist) == size(var, 1)
211+
r = get_and_set_val!(vi, vns, dist, spl)
212+
lp = sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
213+
var .= r
214+
return var, lp
215+
end
216+
function dot_assume(
217+
spl::Union{SampleFromPrior, SampleFromUniform},
218+
dists::Union{Distribution, AbstractArray{<:Distribution}},
219+
vns::AbstractArray{<:VarName},
220+
var::AbstractArray,
221+
vi::VarInfo,
222+
)
223+
r = get_and_set_val!(vi, vns, dists, spl)
224+
# Make sure `r` is not a matrix for multivariate distributions
225+
lp = sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
226+
var .= r
227+
return var, lp
228+
end
229+
function dot_assume(
230+
spl::Sampler,
231+
::Any,
232+
::AbstractArray{<:VarName},
233+
::Any,
234+
::VarInfo
235+
)
236+
error("[Turing] $(alg_str(spl)) doesn't support vectorizing assume statement")
237+
end
238+
239+
function get_and_set_val!(
240+
vi::VarInfo,
241+
vns::AbstractVector{<:VarName},
242+
dist::MultivariateDistribution,
243+
spl::AbstractSampler,
244+
)
245+
n = length(vns)
246+
if haskey(vi, vns[1])
247+
if is_flagged(vi, vns[1], "del")
248+
unset_flag!(vi, vns[1], "del")
249+
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
250+
for i in 1:n
251+
vn = vns[i]
252+
vi[vn] = vectorize(dist, r[:, i])
253+
setorder!(vi, vn, get_num_produce(vi))
254+
end
255+
else
256+
r = vi[vns]
257+
end
258+
else
259+
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
260+
for i in 1:n
261+
push!(vi, vns[i], r[:,i], dist, spl)
262+
end
263+
end
264+
return r
265+
end
266+
function get_and_set_val!(
267+
vi::VarInfo,
268+
vns::AbstractArray{<:VarName},
269+
dists::Union{Distribution, AbstractArray{<:Distribution}},
270+
spl::AbstractSampler,
271+
)
272+
if haskey(vi, vns[1])
273+
if is_flagged(vi, vns[1], "del")
274+
unset_flag!(vi, vns[1], "del")
275+
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
276+
r = f.(vns, dists)
277+
for i in eachindex(vns)
278+
vn = vns[i]
279+
dist = dists isa AbstractArray ? dists[i] : dists
280+
vi[vn] = vectorize(dist, r[i])
281+
setorder!(vi, vn, get_num_produce(vi))
282+
end
283+
else
284+
r = reshape(vi[vec(vns)], size(vns))
285+
end
286+
else
287+
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
288+
r = f.(vns, dists)
289+
push!.(Ref(vi), vns, r, dists, Ref(spl))
290+
end
291+
return r
292+
end
293+
294+
function set_val!(
295+
vi::VarInfo,
296+
vns::AbstractVector{<:VarName},
297+
dist::MultivariateDistribution,
298+
val::AbstractMatrix,
299+
)
300+
@assert size(val, 2) == length(vns)
301+
foreach(enumerate(vns)) do (i, vn)
302+
vi[vn] = val[:,i]
303+
end
304+
return val
305+
end
306+
function set_val!(
307+
vi::VarInfo,
308+
vns::AbstractArray{<:VarName},
309+
dists::Union{Distribution, AbstractArray{<:Distribution}},
310+
val::AbstractArray,
311+
)
312+
@assert size(val) == size(vns)
313+
foreach(CartesianIndices(val)) do ind
314+
dist = dists isa AbstractArray ? dists[ind] : dists
315+
vi[vns[ind]] = vectorize(dist, val[ind])
316+
end
317+
return val
318+
end
319+
320+
# observe
321+
function dot_tilde(ctx::DefaultContext, sampler, right, left, vi)
322+
return _dot_tilde(sampler, right, left, vi)
323+
end
324+
function dot_tilde(ctx::PriorContext, sampler, right, left, vi)
325+
return 0
326+
end
327+
function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi)
328+
return _dot_tilde(sampler, right, left, vi)
329+
end
330+
function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
331+
return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, left, vi)
332+
end
333+
334+
function _dot_tilde(sampler, right, left::AbstractArray, vi)
335+
return dot_observe(sampler, right, left, vi)
336+
end
337+
# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
338+
function _dot_tilde(
339+
sampler::AbstractSampler,
340+
right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}},
341+
left::AbstractMatrix{>:AbstractVector},
342+
vi::VarInfo,
343+
)
344+
throw(ambiguity_error_msg())
345+
end
346+
347+
function dot_observe(
348+
spl::Union{SampleFromPrior, SampleFromUniform},
349+
dist::MultivariateDistribution,
350+
value::AbstractMatrix,
351+
vi::VarInfo,
352+
)
353+
increment_num_produce!(vi)
354+
Turing.DEBUG && @debug "dist = $dist"
355+
Turing.DEBUG && @debug "value = $value"
356+
return sum(logpdf(dist, value))
357+
end
358+
function dot_observe(
359+
spl::Union{SampleFromPrior, SampleFromUniform},
360+
dists::Union{Distribution, AbstractArray{<:Distribution}},
361+
value::AbstractArray,
362+
vi::VarInfo,
363+
)
364+
increment_num_produce!(vi)
365+
Turing.DEBUG && @debug "dists = $dists"
366+
Turing.DEBUG && @debug "value = $value"
367+
return sum(logpdf.(dists, value))
368+
end
369+
function dot_observe(
370+
spl::Sampler,
371+
::Any,
372+
::Any,
373+
::VarInfo,
374+
)
375+
error("[Turing] $(alg_str(spl)) doesn't support vectorizing observe statement")
376+
end

0 commit comments

Comments
 (0)