Skip to content

Commit c5653ba

Browse files
committed
Fix a bunch of tests
1 parent 5c7a156 commit c5653ba

File tree

11 files changed

+223
-339
lines changed

11 files changed

+223
-339
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242
return keys(c.info.varname_to_symbol)
4343
end
4444

45+
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
46+
_check_varname_indexing(c)
47+
d = Dict{VarName}()
48+
for vn in DynamicPPL.varnames(c)
49+
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
50+
end
51+
return d
52+
end
53+
4554
"""
4655
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
4756
@@ -114,9 +123,19 @@ function DynamicPPL.predict(
114123

115124
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116125
predictive_samples = map(iters) do (sample_idx, chain_idx)
117-
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))
119-
126+
# Extract values from the chain
127+
values_dict = DynamicPPL.chain_sample_to_varname_dict(
128+
parameter_only_chain, sample_idx, chain_idx
129+
)
130+
# Resample any variables that are not present in `values_dict`
131+
_, varinfo = last(
132+
DynamicPPL.init!!(
133+
rng,
134+
model,
135+
varinfo,
136+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
137+
),
138+
)
120139
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121140
varname_vals = mapreduce(
122141
collect,
@@ -248,13 +267,20 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
248267
varinfo = DynamicPPL.VarInfo(model)
249268
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
250269
return map(iters) do (sample_idx, chain_idx)
251-
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
252-
# Update the varinfo with the current sample and make variables not present in `chain`
253-
# to be sampled.
254-
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255-
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256-
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
270+
# Extract values from the chain
271+
values_dict = DynamicPPL.chain_sample_to_varname_dict(
272+
parameter_only_chain, sample_idx, chain_idx
273+
)
274+
# Resample any variables that are not present in `values_dict`, and
275+
# return the model's retval (`first`).
276+
first(
277+
DynamicPPL.init!!(
278+
rng,
279+
model,
280+
varinfo,
281+
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
282+
),
283+
)
258284
end
259285
end
260286

src/debug_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,9 @@ function check_model_and_trace(
438438
kwargs...,
439439
)
440440
# Execute the model with the debug context.
441-
new_context = setleafcontext(model.context, InitContext(rng, Prior()))
441+
new_context = DynamicPPL.setleafcontext(
442+
model.context, DynamicPPL.InitContext(rng, DynamicPPL.PriorInit())
443+
)
442444
debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...)
443445
debug_model = DynamicPPL.contextualize(model, debug_context)
444446

src/model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,8 +1165,8 @@ function predict(
11651165
varinfo = DynamicPPL.VarInfo(model)
11661166
return map(chain) do params_varinfo
11671167
vi = deepcopy(varinfo)
1168-
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1169-
model(rng, vi)
1168+
values_nt = values_as(params_varinfo, NamedTuple)
1169+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
11701170
return vi
11711171
end
11721172
end

src/test_utils/contexts.jl

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
2929
node_trait = DynamicPPL.NodeTrait(context)
3030
# Throw error immediately if it it's missing a `NodeTrait` implementation.
3131
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
32-
throw(ValueError("Invalid NodeTrait: $node_trait"))
32+
error("Invalid NodeTrait: $node_trait")
3333

34-
# To see change, let's make sure we're using a different leaf context than the current.
35-
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
36-
DynamicPPL.DynamicTransformationContext{false}()
34+
if node_trait isa DynamicPPL.IsLeaf
35+
test_leaf_context(context, model)
3736
else
38-
DefaultContext()
37+
test_parent_context(context, model)
3938
end
40-
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
41-
leafcontext_new
39+
end
40+
41+
function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
42+
@test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf
43+
44+
# Note that for a leaf context we can't assume that it will work with an
45+
# empty VarInfo. Thus we only test evaluation (i.e., assuming that the
46+
# varinfo already contains all necessary variables).
47+
@testset "evaluation" begin
48+
# Generate a new filled untyped varinfo
49+
_, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
50+
typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
51+
new_model = contextualize(model, context)
52+
for vi in [untyped_vi, typed_vi]
53+
_, vi = DynamicPPL.evaluate!!(new_model, vi)
54+
@test vi isa DynamicPPL.VarInfo
55+
end
56+
end
57+
end
58+
59+
function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
60+
@test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
4261

43-
# The interface methods.
44-
if node_trait isa DynamicPPL.IsParent
45-
# `childcontext` and `setchildcontext`
46-
# With new child context
62+
@testset "{set,}{leaf,child}context" begin
63+
# Ensure we're using a different leaf context than the current.
64+
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
65+
DynamicPPL.DynamicTransformationContext{false}()
66+
else
67+
DefaultContext()
68+
end
69+
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
70+
leafcontext_new
4771
childcontext_new = TestParentContext()
4872
@test DynamicPPL.childcontext(
4973
DynamicPPL.setchildcontext(context, childcontext_new)
@@ -56,15 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
5680
leafcontext_new
5781
end
5882

59-
# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
60-
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
61-
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
62-
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
63-
# Untyped varinfo.
64-
varinfo_untyped = DynamicPPL.VarInfo()
65-
new_model = contextualize(model, context)
66-
@test DynamicPPL.evaluate!!(new_model, varinfo_untyped) isa Any
67-
# Typed varinfo.
68-
varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
69-
@test DynamicPPL.evaluate!!(new_model, varinfo_typed) isa Any
83+
@testset "initialisation and evaluation" begin
84+
new_model = contextualize(model, context)
85+
for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())]
86+
# Initialisation
87+
_, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
88+
@test vi isa DynamicPPL.VarInfo
89+
# Evaluation
90+
_, vi = DynamicPPL.evaluate!!(new_model, vi)
91+
@test vi isa DynamicPPL.VarInfo
92+
end
93+
end
7094
end

src/varinfo.jl

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,113 +2045,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke
20452045
return indices
20462046
end
20472047

2048-
"""
2049-
setval_and_resample!(vi::VarInfo, x)
2050-
setval_and_resample!(vi::VarInfo, values, keys)
2051-
setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx)
2052-
2053-
Set the values in `vi` to the provided values and those which are not present
2054-
in `x` or `chains` to *be* resampled.
2055-
2056-
Note that this does *not* resample the values not provided! It will call
2057-
`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means
2058-
that the next time we call `model(vi)` these variables will be resampled.
2059-
2060-
## Note
2061-
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
2062-
2063-
## Example
2064-
```jldoctest
2065-
julia> using DynamicPPL, Distributions, StableRNGs
2066-
2067-
julia> @model function demo(x)
2068-
m ~ Normal()
2069-
for i in eachindex(x)
2070-
x[i] ~ Normal(m, 1)
2071-
end
2072-
end;
2073-
2074-
julia> rng = StableRNG(42);
2075-
2076-
julia> m = demo([missing]);
2077-
2078-
julia> var_info = DynamicPPL.VarInfo(rng, m);
2079-
# Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set.
2080-
2081-
julia> var_info[@varname(m)]
2082-
-0.6702516921145671
2083-
2084-
julia> var_info[@varname(x[1])]
2085-
-0.22312984965118443
2086-
2087-
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
2088-
2089-
julia> var_info[@varname(m)] # [✓] changed
2090-
100.0
2091-
2092-
julia> var_info[@varname(x[1])] # [✓] unchanged
2093-
-0.22312984965118443
2094-
2095-
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
2096-
2097-
julia> var_info[@varname(m)] # [✓] unchanged
2098-
100.0
2099-
2100-
julia> var_info[@varname(x[1])] # [✓] changed
2101-
101.37363069798343
2102-
```
2103-
2104-
## See also
2105-
- [`setval!`](@ref)
2106-
"""
2107-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x)
2108-
return setval_and_resample!(vi, values(x), keys(x))
2109-
end
2110-
function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys)
2111-
return _apply!(_setval_and_resample_kernel!, vi, values, keys)
2112-
end
2113-
function setval_and_resample!(
2114-
vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
2115-
)
2116-
if supports_varname_indexing(chains)
2117-
# First we need to set every variable to be resampled.
2118-
for vn in keys(vi)
2119-
set_flag!(vi, vn, "del")
2120-
end
2121-
# Then we set the variables in `varinfo` from `chain`.
2122-
for vn in varnames(chains)
2123-
vn_updated = nested_setindex_maybe!(
2124-
vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
2125-
)
2126-
2127-
# Unset the `del` flag if we found something.
2128-
if vn_updated !== nothing
2129-
# NOTE: This will be triggered even if only a subset of a variable has been set!
2130-
unset_flag!(vi, vn_updated, "del")
2131-
end
2132-
end
2133-
else
2134-
setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
2135-
end
2136-
end
2137-
2138-
function _setval_and_resample_kernel!(
2139-
vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys
2140-
)
2141-
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
2142-
if !isempty(indices)
2143-
val = reduce(vcat, values[indices])
2144-
setval!(vi, val, vn)
2145-
settrans!!(vi, false, vn)
2146-
else
2147-
# Ensures that we'll resample the variable corresponding to `vn` if we run
2148-
# the model on `vi` again.
2149-
set_flag!(vi, vn, "del")
2150-
end
2151-
2152-
return indices
2153-
end
2154-
21552048
values_as(vi::VarInfo) = vi.metadata
21562049
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
21572050
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})

test/compiler.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ module Issue537 end
194194
@test getlogjoint(varinfo) == lp
195195
@test varinfo_ isa AbstractVarInfo
196196
@test model_.f === model.f
197-
@test model_.context isa InitContext
197+
@test model_.context isa DynamicPPL.InitContext
198198
@test model_.context.rng isa Random.AbstractRNG
199199

200200
# disable warnings
@@ -595,13 +595,13 @@ module Issue537 end
595595
# an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`.
596596
@model empty_model() = return x = 1
597597
empty_vi = VarInfo()
598-
retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi)
598+
retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi)
599599
@test retval_and_vi isa Tuple{Int,typeof(empty_vi)}
600600

601601
# Even if the return-value is `AbstractVarInfo`, we should return
602602
# a `Tuple` with `AbstractVarInfo` in the second component too.
603603
@model demo() = return __varinfo__
604-
retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo())
604+
retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo())
605605
@test svi == SimpleVarInfo()
606606
if Threads.nthreads() > 1
607607
@test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo}
@@ -617,11 +617,11 @@ module Issue537 end
617617
f(x) = return x^2
618618
return f(1.0)
619619
end
620-
retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo())
620+
retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo())
621621
@test retval isa Float64
622622

623623
@model demo() = x ~ Normal()
624-
retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo())
624+
retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo())
625625

626626
# Return-value when using `to_submodel`
627627
@model inner() = x ~ Normal()

test/ext/DynamicPPLJETExt.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,6 @@
7070
)
7171
JET.test_call(f_eval, argtypes_eval)
7272

73-
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
74-
init_model, varinfo
75-
)
76-
JET.test_call(f_sample, argtypes_sample)
7773
# For our demo models, they should all result in typed.
7874
is_typed = varinfo isa DynamicPPL.NTVarInfo
7975
@test is_typed

0 commit comments

Comments
 (0)