Skip to content

Commit b9f427e

Browse files
committed
Use ParamsInit for predict; remove setval_and_resample! and friends
1 parent 24a6453 commit b9f427e

File tree

7 files changed

+49
-216
lines changed

7 files changed

+49
-216
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function _check_varname_indexing(c::MCMCChains.Chains)
3030
return DynamicPPL.supports_varname_indexing(c) ||
31-
error("Chains do not support indexing using `VarName`s.")
31+
error("This `Chains` object does not support indexing using `VarName`s.")
3232
end
3333

3434
function DynamicPPL.getindex_varname(
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{DynamicPPL.VarName,Any}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,15 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
128+
# Resample any variables that are not present in `values_dict`
129+
_, varinfo = DynamicPPL.init!!(
130+
rng,
131+
model,
132+
varinfo,
133+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
134+
)
120135
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121136
varname_vals = mapreduce(
122137
collect,
@@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248263
varinfo = DynamicPPL.VarInfo(model)
249264
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250265
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
266+
# Extract values from the chain
267+
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
268+
# Resample any variables that are not present in `values_dict`, and
269+
# return the model's retval.
270+
retval, _ = DynamicPPL.init!!(
271+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
272+
)
273+
retval
258274
end
259275
end
260276

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,8 +1209,15 @@ function predict(
12091209
varinfo = DynamicPPL.VarInfo(model)
12101210
return map(chain) do params_varinfo
12111211
vi = deepcopy(varinfo)
1212-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1213-
model(rng, vi)
1212+
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1213+
# parameters and one to set them. The reason why we need values_as_in_model
1214+
# is because `params_varinfo` may well have some weird combination of
1215+
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1216+
# freshly constructed.
1217+
# This is quite inefficient. It would of course be alright if
1218+
# ValuesAsInModelAccumulator was a default acc.
1219+
values_nt = values_as_in_model(model, false, params_varinfo)
1220+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
12141221
return vi
12151222
end
12161223
end

src/varinfo.jl

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,42 +1514,6 @@ function islinked(vi::VarInfo)
15141514
return any(istrans(vi, vn) for vn in keys(vi))
15151515
end
15161516

1517-
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
1518-
return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1519-
end
1520-
function nested_setindex_maybe!(
1521-
vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
1522-
) where {names,sym}
1523-
return if sym in names
1524-
_nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
1525-
else
1526-
nothing
1527-
end
1528-
end
1529-
function _nested_setindex_maybe!(
1530-
vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName
1531-
)
1532-
# If `vn` is in `vns`, then we can just use the standard `setindex!`.
1533-
vns = Base.keys(md)
1534-
if vn in vns
1535-
setindex!(vi, val, vn)
1536-
return vn
1537-
end
1538-
1539-
# Otherwise, we need to check if either of the `vns` subsumes `vn`.
1540-
i = findfirst(Base.Fix2(subsumes, vn), vns)
1541-
i === nothing && return nothing
1542-
1543-
vn_parent = vns[i]
1544-
val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here.
1545-
# Split the varname into its tail optic.
1546-
optic = remove_parent_optic(vn_parent, vn)
1547-
# Update the value for the parent.
1548-
val_parent_updated = set!!(val_parent, optic, val)
1549-
setindex!(vi, val_parent_updated, vn_parent)
1550-
return vn_parent
1551-
end
1552-
15531517
# The default getindex & setindex!() for get & set values
15541518
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
15551519
function getindex(vi::VarInfo, vn::VarName)
@@ -1972,113 +1936,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
19721936
return indices
19731937
end
19741938

1975-
"""
1976-
setval_and_resample!(vi::VarInfo, x)
1977-
setval_and_resample!(vi::VarInfo, values, keys)
1978-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
1979-
1980-
Set the values in `vi` to the provided values and those which are not present
1981-
in `x` or `chains` to *be* resampled.
1982-
1983-
Note that this does *not* resample the values not provided! It will call
1984-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
1985-
that the next time we call `model(vi)` these variables will be resampled.
1986-
1987-
## Note
1988-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
1989-
1990-
## Example
1991-
```jldoctest
1992-
julia> using DynamicPPL, Distributions, StableRNGs
1993-
1994-
julia> @model function demo(x)
1995-
m ~ Normal()
1996-
for i in eachindex(x)
1997-
x[i] ~ Normal(m, 1)
1998-
end
1999-
end;
2000-
2001-
julia> rng = StableRNG(42);
2002-
2003-
julia> m = demo([missing]);
2004-
2005-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2006-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2007-
2008-
julia> var_info[@varname(m)]
2009-
-0.6702516921145671
2010-
2011-
julia> var_info[@varname(x[1])]
2012-
-0.22312984965118443
2013-
2014-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2015-
2016-
julia> var_info[@varname(m)] # [✓] changed
2017-
100.0
2018-
2019-
julia> var_info[@varname(x[1])] # [✓] unchanged
2020-
-0.22312984965118443
2021-
2022-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2023-
2024-
julia> var_info[@varname(m)] # [✓] unchanged
2025-
100.0
2026-
2027-
julia> var_info[@varname(x[1])] # [✓] changed
2028-
101.37363069798343
2029-
```
2030-
2031-
## See also
2032-
- [`setval!`](@ref)
2033-
"""
2034-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2035-
return setval_and_resample!(vi, values(x), keys(x))
2036-
end
2037-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2038-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2039-
end
2040-
function setval_and_resample!(
2041-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2042-
)
2043-
if supports_varname_indexing(chains)
2044-
# First we need to set every variable to be resampled.
2045-
for vn in keys(vi)
2046-
set_flag!(vi, vn, "del")
2047-
end
2048-
# Then we set the variables in `varinfo` from `chain`.
2049-
for vn in varnames(chains)
2050-
vn_updated = nested_setindex_maybe!(
2051-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2052-
)
2053-
2054-
# Unset the `del` flag if we found something.
2055-
if vn_updated !== nothing
2056-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2057-
unset_flag!(vi, vn_updated, "del")
2058-
end
2059-
end
2060-
else
2061-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2062-
end
2063-
end
2064-
2065-
function _setval_and_resample_kernel!(
2066-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2067-
)
2068-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2069-
if !isempty(indices)
2070-
val = reduce(vcat, values[indices])
2071-
setval!(vi, val, vn)
2072-
settrans!!(vi, false, vn)
2073-
else
2074-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2075-
# the model on `vi` again.
2076-
set_flag!(vi, vn, "del")
2077-
end
2078-
2079-
return indices
2080-
end
2081-
20821939
values_as(vi::VarInfo) = vi.metadata
20831940
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
20841941
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
@model demo() = x ~ Normal()
33
model = demo()
44

5-
chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y]))
5+
chain = MCMCChains.Chains(
6+
randn(1000, 2, 1),
7+
[:x, :y],
8+
Dict(:internals => [:y]);
9+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
10+
)
611
chain_generated = @test_nowarn returned(model, chain)
712
@test size(chain_generated) == (1000, 1)
813
@test mean(chain_generated) 0 atol = 0.1

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
580580
xs_train = 1:0.1:10
581581
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
582582
m_lin_reg = linear_reg(xs_train, ys_train)
583-
chain = [VarInfo(m_lin_reg) _ in 1:10000]
583+
chain = [VarInfo(m_lin_reg) for _ in 1:10000]
584584

585585
# chain is generated from the prior
586586
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8787
varnames = collect(varnames)
8888
# Construct matrix of values
8989
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
90+
# Construct dict of varnames -> symbol
91+
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
9092
# Construct and return the Chains object
91-
return Chains(vals, varnames)
93+
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
9294
end
9395
function make_chain_from_prior(model::Model, n_iters::Int)
9496
return make_chain_from_prior(Random.default_rng(), model, n_iters)

test/varinfo.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ end
325325
@test typed_vi[vn_y] == 2.0
326326
end
327327

328-
@testset "setval! & setval_and_resample!" begin
328+
@testset "setval!" begin
329329
@model function testmodel(x)
330330
n = length(x)
331331
s ~ truncated(Normal(); lower=0)
@@ -376,8 +376,8 @@ end
376376
else
377377
DynamicPPL.setval!(vicopy, (m=zeros(5),))
378378
end
379-
# Setting `m` fails for univariate due to limitations of `setval!`
380-
# and `setval_and_resample!`. See docstring of `setval!` for more info.
379+
# Setting `m` fails for univariate due to limitations of `setval!`.
380+
# See docstring of `setval!` for more info.
381381
if model == model_uv && vi in [vi_untyped, vi_typed]
382382
@test_broken vicopy[m_vns] == zeros(5)
383383
else
@@ -402,57 +402,6 @@ end
402402
DynamicPPL.setval!(vicopy, (s=42,))
403403
@test vicopy[m_vns] == 1:5
404404
@test vicopy[s_vns] == 42
405-
406-
### `setval_and_resample!` ###
407-
if model == model_mv && vi == vi_untyped
408-
# Trying to re-run model with `MvNormal` on `vi_untyped` will call
409-
# `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError`
410-
# so we skip this particular case.
411-
continue
412-
end
413-
414-
if vi in [vi_vnv, vi_vnv_typed]
415-
# `setval_and_resample!` works differently for `VarNamedVector`: All
416-
# values will be resampled when model(vicopy) is called. Hence the below
417-
# tests are not applicable.
418-
continue
419-
end
420-
421-
vicopy = deepcopy(vi)
422-
DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),))
423-
model(vicopy)
424-
# Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)`
425-
if model == model_uv
426-
@test_broken vicopy[m_vns] == zeros(5)
427-
else
428-
@test vicopy[m_vns] == zeros(5)
429-
end
430-
@test vicopy[s_vns] != vi[s_vns]
431-
432-
# Ordering is NOT preserved.
433-
DynamicPPL.setval_and_resample!(
434-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)
435-
)
436-
model(vicopy)
437-
if model == model_uv
438-
@test vicopy[m_vns] == 1:5
439-
else
440-
@test vicopy[m_vns] == [1, 3, 5, 4, 2]
441-
end
442-
@test vicopy[s_vns] != vi[s_vns]
443-
444-
# Correct ordering.
445-
DynamicPPL.setval_and_resample!(
446-
vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...)
447-
)
448-
model(vicopy)
449-
@test vicopy[m_vns] == 1:5
450-
@test vicopy[s_vns] != vi[s_vns]
451-
452-
DynamicPPL.setval_and_resample!(vicopy, (s=42,))
453-
model(vicopy)
454-
@test vicopy[m_vns] != 1:5
455-
@test vicopy[s_vns] == 42
456405
end
457406
end
458407

@@ -466,9 +415,6 @@ end
466415
ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])]
467416
DynamicPPL.setval!(vi, vi.metadata.x.vals, ks)
468417
@test vals_prev == vi.metadata.x.vals
469-
470-
DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks)
471-
@test vals_prev == vi.metadata.x.vals
472418
end
473419

474420
@testset "setval! on chain" begin

0 commit comments

Comments
 (0)