Skip to content

Commit b9724bc

Browse files
committed
Some more fixes
1 parent 3a09b52 commit b9724bc

File tree

6 files changed

+34
-25
lines changed

6 files changed

+34
-25
lines changed

src/threadsafe.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
5555
getmetadata(vi::ThreadSafeVarInfo, vn::VarName) = getmetadata(vi.varinfo, vn)
5656
getidx(vi::ThreadSafeVarInfo, vn::VarName) = getidx(vi.varinfo, vn)
5757
getrange(vi::ThreadSafeVarInfo, vn::VarName) = getrange(vi.varinfo, vn)
58+
getdist(vi::ThreadSafeVarInfo, vn::VarName) = getdist(vi.varinfo, vn)
59+
60+
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
61+
setgid!(vi.varinfo, gid, vn)
62+
end
5863

5964
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
6065
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)

test/Turing/inference/AdvancedSMC.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,12 @@ function AbstractMCMC.sample_end!(
260260
spl.state.average_logevidence = loge
261261
end
262262

263-
function DynamicPPL.assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, ::VarInfo)
263+
function DynamicPPL.assume(
264+
spl::Sampler{<:Union{PG,SMC}},
265+
dist::Distribution,
266+
vn::VarName,
267+
vi
268+
)
264269
vi = current_trace().vi
265270
if DynamicPPL.inspace(vn, getspace(spl))
266271
if ~haskey(vi, vn)

test/Turing/inference/hmc.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -391,25 +391,25 @@ end
391391
#####
392392

393393
"""
394-
gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)
394+
gen_∂logπ∂θ(vi, spl::Sampler, model)
395395
396396
Generate a function that takes a vector of reals `θ` and compute the logpdf and
397397
gradient at `θ` for the model specified by `(vi, spl, model)`.
398398
"""
399-
function gen_∂logπ∂θ(vi::VarInfo, spl::Sampler, model)
399+
function gen_∂logπ∂θ(vi, spl::Sampler, model)
400400
function ∂logπ∂θ(x)
401401
return gradient_logp(x, vi, model, spl)
402402
end
403403
return ∂logπ∂θ
404404
end
405405

406406
"""
407-
gen_logπ(vi::VarInfo, spl::Sampler, model)
407+
gen_logπ(vi, spl::Sampler, model)
408408
409409
Generate a function that takes `θ` and returns logpdf at `θ` for the model specified by
410410
`(vi, spl, model)`.
411411
"""
412-
function gen_logπ(vi::VarInfo, spl::Sampler, model)
412+
function gen_logπ(vi, spl::Sampler, model)
413413
function logπ(x)::Float64
414414
x_old, lj_old = vi[spl], getlogp(vi)
415415
vi[spl] = x
@@ -437,7 +437,7 @@ function DynamicPPL.assume(
437437
spl::Sampler{<:Hamiltonian},
438438
dist::Distribution,
439439
vn::VarName,
440-
vi::VarInfo
440+
vi,
441441
)
442442
Turing.DEBUG && _debug("assuming...")
443443
updategid!(vi, vn, spl)
@@ -455,7 +455,7 @@ function DynamicPPL.dot_assume(
455455
dist::MultivariateDistribution,
456456
vns::AbstractArray{<:VarName},
457457
var::AbstractMatrix,
458-
vi::VarInfo,
458+
vi,
459459
)
460460
@assert length(dist) == size(var, 1)
461461
updategid!.(Ref(vi), vns, Ref(spl))
@@ -468,7 +468,7 @@ function DynamicPPL.dot_assume(
468468
dists::Union{Distribution, AbstractArray{<:Distribution}},
469469
vns::AbstractArray{<:VarName},
470470
var::AbstractArray,
471-
vi::VarInfo,
471+
vi,
472472
)
473473
updategid!.(Ref(vi), vns, Ref(spl))
474474
r = reshape(vi[vec(vns)], size(var))
@@ -480,7 +480,7 @@ function DynamicPPL.observe(
480480
spl::Sampler{<:Hamiltonian},
481481
d::Distribution,
482482
value,
483-
vi::VarInfo,
483+
vi,
484484
)
485485
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
486486
end
@@ -489,7 +489,7 @@ function DynamicPPL.dot_observe(
489489
spl::Sampler{<:Hamiltonian},
490490
ds::Union{Distribution, AbstractArray{<:Distribution}},
491491
value::AbstractArray,
492-
vi::VarInfo,
492+
vi,
493493
)
494494
return DynamicPPL.dot_observe(SampleFromPrior(), ds, value, vi)
495495
end

test/Turing/inference/is.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ function AbstractMCMC.sample_end!(
7070
spl.state.final_logevidence = logsumexp(map(x->x.lp, ts)) - log(N)
7171
end
7272

73-
function DynamicPPL.assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi::VarInfo)
73+
function DynamicPPL.assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
7474
r = rand(dist)
7575
push!(vi, vn, r, dist, spl)
7676
return r, 0
7777
end
7878

79-
function DynamicPPL.observe(spl::Sampler{<:IS}, dist::Distribution, value, vi::VarInfo)
79+
function DynamicPPL.observe(spl::Sampler{<:IS}, dist::Distribution, value, vi)
8080
# acclogp!(vi, logpdf(dist, value))
8181
return logpdf(dist, value)
8282
end

test/Turing/inference/mh.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ function set_namedtuple!(vi::VarInfo, nt::NamedTuple)
9494
end
9595

9696
"""
97-
gen_logπ_mh(vi::VarInfo, spl::Sampler, model)
97+
gen_logπ_mh(vi, spl::Sampler, model)
9898
99-
Generate a log density function -- this variant uses the
99+
Generate a log density function -- this variant uses the
100100
`set_namedtuple!` function to update the `VarInfo`.
101101
"""
102102
function gen_logπ_mh(spl::Sampler, model)
@@ -114,7 +114,7 @@ function gen_logπ_mh(spl::Sampler, model)
114114
return logπ
115115
end
116116

117-
function scalar_map(vi::VarInfo, vns::Vector{V}) where V<:VarName
117+
function scalar_map(vi, vns::Vector{V}) where V<:VarName
118118
vals = getindex(vi, vns)
119119
if length(vals) == length(vns) == 1
120120
# It's a scalar!
@@ -138,7 +138,7 @@ function dist_val_tuple(spl::Sampler{<:MH})
138138
return dt, vt
139139
end
140140

141-
@generated function _val_tuple(metadata::NamedTuple, vi::VarInfo, vns::NamedTuple{names}) where {names}
141+
@generated function _val_tuple(metadata::NamedTuple, vi, vns::NamedTuple{names}) where {names}
142142
length(names) === 0 && return :(NamedTuple())
143143
expr = Expr(:tuple)
144144
map(names) do f
@@ -150,7 +150,7 @@ end
150150
@generated function _dist_tuple(
151151
metadata::NamedTuple,
152152
props::NamedTuple{propnames},
153-
vi::VarInfo,
153+
vi,
154154
vns::NamedTuple{names}
155155
) where {names, propnames}
156156
length(names) === 0 && return :(NamedTuple())
@@ -258,7 +258,7 @@ function DynamicPPL.assume(
258258
spl::Sampler{<:MH},
259259
dist::Distribution,
260260
vn::VarName,
261-
vi::VarInfo
261+
vi,
262262
)
263263
updategid!(vi, vn, spl)
264264
r = vi[vn]
@@ -270,7 +270,7 @@ function DynamicPPL.dot_assume(
270270
dist::MultivariateDistribution,
271271
vn::VarName,
272272
var::AbstractMatrix,
273-
vi::VarInfo,
273+
vi,
274274
)
275275
@assert dim(dist) == size(var, 1)
276276
getvn = i -> VarName(vn, vn.indexing * "[:,$i]")
@@ -285,7 +285,7 @@ function DynamicPPL.dot_assume(
285285
dists::Union{Distribution, AbstractArray{<:Distribution}},
286286
vn::VarName,
287287
var::AbstractArray,
288-
vi::VarInfo,
288+
vi,
289289
)
290290
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
291291
vns = getvn.(CartesianIndices(var))
@@ -299,7 +299,7 @@ function DynamicPPL.observe(
299299
spl::Sampler{<:MH},
300300
d::Distribution,
301301
value,
302-
vi::VarInfo,
302+
vi,
303303
)
304304
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
305305
end
@@ -308,7 +308,7 @@ function DynamicPPL.dot_observe(
308308
spl::Sampler{<:MH},
309309
ds::Union{Distribution, AbstractArray{<:Distribution}},
310310
value::AbstractArray,
311-
vi::VarInfo,
311+
vi,
312312
)
313313
return DynamicPPL.dot_observe(SampleFromPrior(), ds, value, vi)
314314
end

test/compiler.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ end
215215
global sampler_ = _sampler
216216
global model_ = _model
217217
global context_ = _context
218-
global logps_ = _logps
219-
global lp = sum(_logps)
218+
global lp = getlogp(_varinfo)
220219
return x
221220
end
222221
model = testmodel([1.0])
@@ -250,7 +249,7 @@ end
250249
function makemodel(p)
251250
@model testmodel(x) = begin
252251
x[1] ~ Bernoulli(p)
253-
global lp = sum(_logps)
252+
global lp = getlogp(_varinfo)
254253
return x
255254
end
256255
return testmodel

0 commit comments

Comments
 (0)