Skip to content

Commit b24fc96

Browse files
committed
Use ParamsInit for predict; remove setval_and_resample! and friends
1 parent 5720b62 commit b24fc96

File tree

7 files changed

+49
-216
lines changed

7 files changed

+49
-216
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function _check_varname_indexing(c::MCMCChains.Chains)
3030
return DynamicPPL.supports_varname_indexing(c) ||
31-
error("Chains do not support indexing using `VarName`s.")
31+
error("This `Chains` object does not support indexing using `VarName`s.")
3232
end
3333

3434
function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{DynamicPPL.VarName,Any}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,15 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
128+
# Resample any variables that are not present in `values_dict`
129+
_, varinfo = DynamicPPL.init!!(
130+
rng,
131+
model,
132+
varinfo,
133+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
134+
)
120135
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121136
varname_vals = mapreduce(
122137
collect,
@@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248263
varinfo = DynamicPPL.VarInfo(model)
249264
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250265
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
266+
# Extract values from the chain
267+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
268+
# Resample any variables that are not present in `values_dict`, and
269+
# return the model's retval.
270+
retval, _ = DynamicPPL.init!!(
271+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
272+
)
273+
retval
258274
end
259275
end
260276

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,8 +1200,15 @@ function predict(
12001200
varinfo = DynamicPPL.VarInfo(model)
12011201
return map(chain) do params_varinfo
12021202
vi = deepcopy(varinfo)
1203-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1204-
model(rng, vi)
1203+
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1204+
# parameters and one to set them. The reason why we need values_as_in_model
1205+
# is because `params_varinfo` may well have some weird combination of
1206+
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1207+
# freshly constructed.
1208+
# This is quite inefficient. It would of course be alright if
1209+
# ValuesAsInModelAccumulator was a default acc.
1210+
values_nt = values_as_in_model(model, false, params_varinfo)
1211+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
12051212
return vi
12061213
end
12071214
end

src/varinfo.jl

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,42 +1498,6 @@ function islinked(vi::VarInfo)
14981498
return any(istrans(vi, vn) for vn in keys(vi))
14991499
end
15001500

1501-
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
1502-
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1503-
end
1504-
function nested_setindex_maybe!(
1505-
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
1506-
) where {names,sym}
1507-
return if sym in names
1508-
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1509-
else
1510-
nothing
1511-
end
1512-
end
1513-
function _nested_setindex_maybe!(
1514-
vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName
1515-
)
1516-
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
1517-
vns = Base.keys(md)
1518-
if vn in vns
1519-
setindex!(vi, val, vn)
1520-
return vn
1521-
end
1522-
1523-
# Otherwise, we need to check if either of the `vns` subsumes `vn`.
1524-
i = findfirst(Base.Fix2(subsumes, vn), vns)
1525-
i === nothing && return nothing
1526-
1527-
vn_parent = vns[i]
1528-
val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here.
1529-
# Split the varname into its tail optic.
1530-
optic = remove_parent_optic(vn_parent, vn)
1531-
# Update the value for the parent.
1532-
val_parent_updated = set!!(val_parent, optic, val)
1533-
setindex!(vi, val_parent_updated, vn_parent)
1534-
return vn_parent
1535-
end
1536-
15371501
# The default getindex & setindex!() for get & set values
15381502
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
15391503
function getindex(vi::VarInfo, vn::VarName)
@@ -1978,113 +1942,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
19781942
return indices
19791943
end
19801944

1981-
"""
1982-
setval_and_resample!(vi::VarInfo, x)
1983-
setval_and_resample!(vi::VarInfo, values, keys)
1984-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
1985-
1986-
Set the values in `vi` to the provided values and those which are not present
1987-
in `x` or `chains` to *be* resampled.
1988-
1989-
Note that this does *not* resample the values not provided! It will call
1990-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
1991-
that the next time we call `model(vi)` these variables will be resampled.
1992-
1993-
## Note
1994-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
1995-
1996-
## Example
1997-
```jldoctest
1998-
julia> using DynamicPPL, Distributions, StableRNGs
1999-
2000-
julia> @model function demo(x)
2001-
m ~ Normal()
2002-
for i in eachindex(x)
2003-
x[i] ~ Normal(m, 1)
2004-
end
2005-
end;
2006-
2007-
julia> rng = StableRNG(42);
2008-
2009-
julia> m = demo([missing]);
2010-
2011-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2012-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2013-
2014-
julia> var_info[@varname(m)]
2015-
-0.6702516921145671
2016-
2017-
julia> var_info[@varname(x[1])]
2018-
-0.22312984965118443
2019-
2020-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2021-
2022-
julia> var_info[@varname(m)] # [✓] changed
2023-
100.0
2024-
2025-
julia> var_info[@varname(x[1])] # [✓] unchanged
2026-
-0.22312984965118443
2027-
2028-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2029-
2030-
julia> var_info[@varname(m)] # [✓] unchanged
2031-
100.0
2032-
2033-
julia> var_info[@varname(x[1])] # [✓] changed
2034-
101.37363069798343
2035-
```
2036-
2037-
## See also
2038-
- [`setval!`](@ref)
2039-
"""
2040-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2041-
return setval_and_resample!(vi, values(x), keys(x))
2042-
end
2043-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2044-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2045-
end
2046-
function setval_and_resample!(
2047-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2048-
)
2049-
if supports_varname_indexing(chains)
2050-
# First we need to set every variable to be resampled.
2051-
for vn in keys(vi)
2052-
set_flag!(vi, vn, "del")
2053-
end
2054-
# Then we set the variables in `varinfo` from `chain`.
2055-
for vn in varnames(chains)
2056-
vn_updated = nested_setindex_maybe!(
2057-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2058-
)
2059-
2060-
# Unset the `del` flag if we found something.
2061-
if vn_updated !== nothing
2062-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2063-
unset_flag!(vi, vn_updated, "del")
2064-
end
2065-
end
2066-
else
2067-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2068-
end
2069-
end
2070-
2071-
function _setval_and_resample_kernel!(
2072-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2073-
)
2074-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2075-
if !isempty(indices)
2076-
val = reduce(vcat, values[indices])
2077-
setval!(vi, val, vn)
2078-
settrans!!(vi, false, vn)
2079-
else
2080-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2081-
# the model on `vi` again.
2082-
set_flag!(vi, vn, "del")
2083-
end
2084-
2085-
return indices
2086-
end
2087-
20881945
values_as(vi::VarInfo) = vi.metadata
20891946
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
20901947
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
@model demo() = x ~ Normal()
33
model = demo()
44

5-
chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y]))
5+
chain = MCMCChains.Chains(
6+
randn(1000, 2, 1),
7+
[:x, :y],
8+
Dict(:internals => [:y]);
9+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
10+
)
611
chain_generated = @test_nowarn returned(model, chain)
712
@test size(chain_generated) == (1000, 1)
813
@test mean(chain_generated) 0 atol = 0.1

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
573573
xs_train = 1:0.1:10
574574
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
575575
m_lin_reg = linear_reg(xs_train, ys_train)
576-
chain = [VarInfo(m_lin_reg) _ in 1:10000]
576+
chain = [VarInfo(m_lin_reg) for _ in 1:10000]
577577

578578
# chain is generated from the prior
579579
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8181
varnames = collect(varnames)
8282
# Construct matrix of values
8383
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
84+
# Construct dict of varnames -> symbol
85+
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
8486
# Construct and return the Chains object
85-
return Chains(vals, varnames)
87+
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
8688
end
8789
function make_chain_from_prior(model::Model, n_iters::Int)
8890
return make_chain_from_prior(Random.default_rng(), model, n_iters)

test/varinfo.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ end
278278
@test typed_vi[vn_y] == 2.0
279279
end
280280

281-
@testset "setval! & setval_and_resample!" begin
281+
@testset "setval!" begin
282282
@model function testmodel(x)
283283
n = length(x)
284284
s ~ truncated(Normal(); lower=0)
@@ -329,8 +329,8 @@ end
329329
else
330330
DynamicPPL.setval!(vicopy, (m=zeros(5),))
331331
end
332-
# Setting `m` fails for univariate due to limitations of `setval!`
333-
# and `setval_and_resample!`. See docstring of `setval!` for more info.
332+
# Setting `m` fails for univariate due to limitations of `setval!`.
333+
# See docstring of `setval!` for more info.
334334
if model == model_uv && vi in [vi_untyped, vi_typed]
335335
@test_broken vicopy[m_vns] == zeros(5)
336336
else
@@ -355,57 +355,6 @@ end
355355
DynamicPPL.setval!(vicopy, (s=42,))
356356
@test vicopy[m_vns] == 1:5
357357
@test vicopy[s_vns] == 42
358-
359-
### `setval_and_resample!` ###
360-
if model == model_mv && vi == vi_untyped
361-
# Trying to re-run model with `MvNormal` on `vi_untyped` will call
362-
# `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError`
363-
# so we skip this particular case.
364-
continue
365-
end
366-
367-
if vi in [vi_vnv, vi_vnv_typed]
368-
# `setval_and_resample!` works differently for `VarNamedVector`: All
369-
# values will be resampled when model(vicopy) is called. Hence the below
370-
# tests are not applicable.
371-
continue
372-
end
373-
374-
vicopy = deepcopy(vi)
375-
DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),))
376-
model(vicopy)
377-
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
378-
if model == model_uv
379-
@test_broken vicopy[m_vns] == zeros(5)
380-
else
381-
@test vicopy[m_vns] == zeros(5)
382-
end
383-
@test vicopy[s_vns] != vi[s_vns]
384-
385-
# Ordering is NOT preserved.
386-
DynamicPPL.setval_and_resample!(
387-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
388-
)
389-
model(vicopy)
390-
if model == model_uv
391-
@test vicopy[m_vns] == 1:5
392-
else
393-
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
394-
end
395-
@test vicopy[s_vns] != vi[s_vns]
396-
397-
# Correct ordering.
398-
DynamicPPL.setval_and_resample!(
399-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
400-
)
401-
model(vicopy)
402-
@test vicopy[m_vns] == 1:5
403-
@test vicopy[s_vns] != vi[s_vns]
404-
405-
DynamicPPL.setval_and_resample!(vicopy, (s=42,))
406-
model(vicopy)
407-
@test vicopy[m_vns] != 1:5
408-
@test vicopy[s_vns] == 42
409358
end
410359
end
411360

@@ -419,9 +368,6 @@ end
419368
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
420369
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
421370
@test vals_prev == vi.metadata.x.vals
422-
423-
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
424-
@test vals_prev == vi.metadata.x.vals
425371
end
426372

427373
@testset "setval! on chain" begin

0 commit comments

Comments
 (0)