Skip to content

Commit e50d305

Browse files
committed
Use ParamsInit for predict; remove setval_and_resample! and friends
1 parent fbcb82b commit e50d305

File tree

7 files changed

+49
-216
lines changed

7 files changed

+49
-216
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function _check_varname_indexing(c::MCMCChains.Chains)
3030
return DynamicPPL.supports_varname_indexing(c) ||
31-
error("Chains do not support indexing using `VarName`s.")
31+
error("This `Chains` object does not support indexing using `VarName`s.")
3232
end
3333

3434
function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{DynamicPPL.VarName,Any}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,15 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
128+
# Resample any variables that are not present in `values_dict`
129+
_, varinfo = DynamicPPL.init!!(
130+
rng,
131+
model,
132+
varinfo,
133+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
134+
)
120135
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121136
varname_vals = mapreduce(
122137
collect,
@@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248263
varinfo = DynamicPPL.VarInfo(model)
249264
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250265
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
266+
# Extract values from the chain
267+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
268+
# Resample any variables that are not present in `values_dict`, and
269+
# return the model's retval.
270+
retval, _ = DynamicPPL.init!!(
271+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
272+
)
273+
retval
258274
end
259275
end
260276

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,8 +1209,15 @@ function predict(
12091209
varinfo = DynamicPPL.VarInfo(model)
12101210
return map(chain) do params_varinfo
12111211
vi = deepcopy(varinfo)
1212-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1213-
model(rng, vi)
1212+
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1213+
# parameters and one to set them. The reason why we need values_as_in_model
1214+
# is because `params_varinfo` may well have some weird combination of
1215+
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1216+
# freshly constructed.
1217+
# This is quite inefficient. It would of course be alright if
1218+
# ValuesAsInModelAccumulator was a default acc.
1219+
values_nt = values_as_in_model(model, false, params_varinfo)
1220+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
12141221
return vi
12151222
end
12161223
end

src/varinfo.jl

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,42 +1514,6 @@ function islinked(vi::VarInfo)
15141514
return any(istrans(vi, vn) for vn in keys(vi))
15151515
end
15161516

1517-
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
1518-
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1519-
end
1520-
function nested_setindex_maybe!(
1521-
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
1522-
) where {names,sym}
1523-
return if sym in names
1524-
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1525-
else
1526-
nothing
1527-
end
1528-
end
1529-
function _nested_setindex_maybe!(
1530-
vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName
1531-
)
1532-
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
1533-
vns = Base.keys(md)
1534-
if vn in vns
1535-
setindex!(vi, val, vn)
1536-
return vn
1537-
end
1538-
1539-
# Otherwise, we need to check if either of the `vns` subsumes `vn`.
1540-
i = findfirst(Base.Fix2(subsumes, vn), vns)
1541-
i === nothing && return nothing
1542-
1543-
vn_parent = vns[i]
1544-
val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here.
1545-
# Split the varname into its tail optic.
1546-
optic = remove_parent_optic(vn_parent, vn)
1547-
# Update the value for the parent.
1548-
val_parent_updated = set!!(val_parent, optic, val)
1549-
setindex!(vi, val_parent_updated, vn_parent)
1550-
return vn_parent
1551-
end
1552-
15531517
# The default getindex & setindex!() for get & set values
15541518
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
15551519
function getindex(vi::VarInfo, vn::VarName)
@@ -1994,113 +1958,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
19941958
return indices
19951959
end
19961960

1997-
"""
1998-
setval_and_resample!(vi::VarInfo, x)
1999-
setval_and_resample!(vi::VarInfo, values, keys)
2000-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
2001-
2002-
Set the values in `vi` to the provided values and those which are not present
2003-
in `x` or `chains` to *be* resampled.
2004-
2005-
Note that this does *not* resample the values not provided! It will call
2006-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
2007-
that the next time we call `model(vi)` these variables will be resampled.
2008-
2009-
## Note
2010-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
2011-
2012-
## Example
2013-
```jldoctest
2014-
julia> using DynamicPPL, Distributions, StableRNGs
2015-
2016-
julia> @model function demo(x)
2017-
m ~ Normal()
2018-
for i in eachindex(x)
2019-
x[i] ~ Normal(m, 1)
2020-
end
2021-
end;
2022-
2023-
julia> rng = StableRNG(42);
2024-
2025-
julia> m = demo([missing]);
2026-
2027-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2028-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2029-
2030-
julia> var_info[@varname(m)]
2031-
-0.6702516921145671
2032-
2033-
julia> var_info[@varname(x[1])]
2034-
-0.22312984965118443
2035-
2036-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2037-
2038-
julia> var_info[@varname(m)] # [✓] changed
2039-
100.0
2040-
2041-
julia> var_info[@varname(x[1])] # [✓] unchanged
2042-
-0.22312984965118443
2043-
2044-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2045-
2046-
julia> var_info[@varname(m)] # [✓] unchanged
2047-
100.0
2048-
2049-
julia> var_info[@varname(x[1])] # [✓] changed
2050-
101.37363069798343
2051-
```
2052-
2053-
## See also
2054-
- [`setval!`](@ref)
2055-
"""
2056-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2057-
return setval_and_resample!(vi, values(x), keys(x))
2058-
end
2059-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2060-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2061-
end
2062-
function setval_and_resample!(
2063-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2064-
)
2065-
if supports_varname_indexing(chains)
2066-
# First we need to set every variable to be resampled.
2067-
for vn in keys(vi)
2068-
set_flag!(vi, vn, "del")
2069-
end
2070-
# Then we set the variables in `varinfo` from `chain`.
2071-
for vn in varnames(chains)
2072-
vn_updated = nested_setindex_maybe!(
2073-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2074-
)
2075-
2076-
# Unset the `del` flag if we found something.
2077-
if vn_updated !== nothing
2078-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2079-
unset_flag!(vi, vn_updated, "del")
2080-
end
2081-
end
2082-
else
2083-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2084-
end
2085-
end
2086-
2087-
function _setval_and_resample_kernel!(
2088-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2089-
)
2090-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2091-
if !isempty(indices)
2092-
val = reduce(vcat, values[indices])
2093-
setval!(vi, val, vn)
2094-
settrans!!(vi, false, vn)
2095-
else
2096-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2097-
# the model on `vi` again.
2098-
set_flag!(vi, vn, "del")
2099-
end
2100-
2101-
return indices
2102-
end
2103-
21041961
values_as(vi::VarInfo) = vi.metadata
21051962
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
21061963
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
@model demo() = x ~ Normal()
33
model = demo()
44

5-
chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y]))
5+
chain = MCMCChains.Chains(
6+
randn(1000, 2, 1),
7+
[:x, :y],
8+
Dict(:internals => [:y]);
9+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
10+
)
611
chain_generated = @test_nowarn returned(model, chain)
712
@test size(chain_generated) == (1000, 1)
813
@test mean(chain_generated) 0 atol = 0.1

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
580580
xs_train = 1:0.1:10
581581
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
582582
m_lin_reg = linear_reg(xs_train, ys_train)
583-
chain = [VarInfo(m_lin_reg) _ in 1:10000]
583+
chain = [VarInfo(m_lin_reg) for _ in 1:10000]
584584

585585
# chain is generated from the prior
586586
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8181
varnames = collect(varnames)
8282
# Construct matrix of values
8383
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
84+
# Construct dict of varnames -> symbol
85+
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
8486
# Construct and return the Chains object
85-
return Chains(vals, varnames)
87+
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
8688
end
8789
function make_chain_from_prior(model::Model, n_iters::Int)
8890
return make_chain_from_prior(Random.default_rng(), model, n_iters)

test/varinfo.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ end
283283
@test typed_vi[vn_y] == 2.0
284284
end
285285

286-
@testset "setval! & setval_and_resample!" begin
286+
@testset "setval!" begin
287287
@model function testmodel(x)
288288
n = length(x)
289289
s ~ truncated(Normal(); lower=0)
@@ -334,8 +334,8 @@ end
334334
else
335335
DynamicPPL.setval!(vicopy, (m=zeros(5),))
336336
end
337-
# Setting `m` fails for univariate due to limitations of `setval!`
338-
# and `setval_and_resample!`. See docstring of `setval!` for more info.
337+
# Setting `m` fails for univariate due to limitations of `setval!`.
338+
# See docstring of `setval!` for more info.
339339
if model == model_uv && vi in [vi_untyped, vi_typed]
340340
@test_broken vicopy[m_vns] == zeros(5)
341341
else
@@ -360,57 +360,6 @@ end
360360
DynamicPPL.setval!(vicopy, (s=42,))
361361
@test vicopy[m_vns] == 1:5
362362
@test vicopy[s_vns] == 42
363-
364-
### `setval_and_resample!` ###
365-
if model == model_mv && vi == vi_untyped
366-
# Trying to re-run model with `MvNormal` on `vi_untyped` will call
367-
# `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError`
368-
# so we skip this particular case.
369-
continue
370-
end
371-
372-
if vi in [vi_vnv, vi_vnv_typed]
373-
# `setval_and_resample!` works differently for `VarNamedVector`: All
374-
# values will be resampled when model(vicopy) is called. Hence the below
375-
# tests are not applicable.
376-
continue
377-
end
378-
379-
vicopy = deepcopy(vi)
380-
DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),))
381-
model(vicopy)
382-
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
383-
if model == model_uv
384-
@test_broken vicopy[m_vns] == zeros(5)
385-
else
386-
@test vicopy[m_vns] == zeros(5)
387-
end
388-
@test vicopy[s_vns] != vi[s_vns]
389-
390-
# Ordering is NOT preserved.
391-
DynamicPPL.setval_and_resample!(
392-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
393-
)
394-
model(vicopy)
395-
if model == model_uv
396-
@test vicopy[m_vns] == 1:5
397-
else
398-
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
399-
end
400-
@test vicopy[s_vns] != vi[s_vns]
401-
402-
# Correct ordering.
403-
DynamicPPL.setval_and_resample!(
404-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
405-
)
406-
model(vicopy)
407-
@test vicopy[m_vns] == 1:5
408-
@test vicopy[s_vns] != vi[s_vns]
409-
410-
DynamicPPL.setval_and_resample!(vicopy, (s=42,))
411-
model(vicopy)
412-
@test vicopy[m_vns] != 1:5
413-
@test vicopy[s_vns] == 42
414363
end
415364
end
416365

@@ -424,9 +373,6 @@ end
424373
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
425374
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
426375
@test vals_prev == vi.metadata.x.vals
427-
428-
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
429-
@test vals_prev == vi.metadata.x.vals
430376
end
431377

432378
@testset "setval! on chain" begin

0 commit comments

Comments
 (0)