Skip to content

Commit 2d6ef3f

Browse files
committed
Resample variable if not given in setval! (#216)
Currently if one calls `DynamicPPL._setval!(vi, vi.metadata, values, keys)`, then only those values present in `keys` will be set, as expected, but the variables which are _not_ present in `keys` will simply be left as-is. This means that we get the following behavior: ``` julia julia> using Turing julia> @model function demo(x) m ~ Normal(0, 1) for i in eachindex(x) x[i] ~ Normal(m, 1) end end demo (generic function with 1 method) julia> m_missing = demo(fill(missing, 2)); julia> var_info_missing = DynamicPPL.VarInfo(m_missing); julia> var_info_missing.metadata.m.vals 1-element Array{Float64,1}: 0.7251417347423874 julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, )); julia> var_info_missing.metadata.m.vals # ✓ new value 1-element Array{Float64,1}: 0.0 julia> var_info_missing.metadata.x.vals # ✓ still the same value 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> m_missing(var_info_missing) # Re-run the model with new value for `m` julia> var_info_missing.metadata.x.vals # × still the same and thus not reflecting the change in `m`! 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 ``` _Personally_ I expected `x` to be resampled since now parts of the model have changed and thus the sample `x` is no longer representative of a sample from the model (under the sampler used). 
This PR "fixes" the above so that you get the following behavior: ``` julia julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, )); julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> m_missing(var_info_missing) julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: -2.0493130638394947 0.3881955730968598 ``` This was discovered when debugging TuringLang/Turing.jl#1352 as I want to move `Turing.predict` over to using `DynamicPPL.setval!` and it also has consequences for `DynamicPPL.generated_quantities` which uses `DynamicPPL.setval!` under the hood and thus suffers from the same issue. There's an alternative: instead of making this the default behavior, we could add `kwargs...` to `setval!` which includes `resample_missing::Bool` or something. I'm also completely fine with a solution like that 👍
1 parent 3602c56 commit 2d6ef3f

File tree

8 files changed

+320
-56
lines changed

8 files changed

+320
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.8"
3+
version = "0.10.9"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/loglikelihoods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function pointwise_loglikelihoods(
173173
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
174174
for (sample_idx, chain_idx) in iters
175175
# Update the values
176-
setval!(vi, chain, sample_idx, chain_idx)
176+
setval_and_resample!(vi, chain, sample_idx, chain_idx)
177177

178178
# Execute model
179179
model(vi, spl, ctx)

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ function generated_quantities(model::Model, chain::AbstractChains)
277277
varinfo = VarInfo(model)
278278
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
279279
return map(iters) do (sample_idx, chain_idx)
280-
setval!(varinfo, chain, sample_idx, chain_idx)
280+
setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
281281
model(varinfo)
282282
end
283283
end

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,10 @@ end
153153
function inittrans(rng, dist::MatrixDistribution, n::Int)
154154
return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n])
155155
end
156+
157+
158+
#######################
159+
# Convenience methods #
160+
#######################
161+
collectmaybe(x) = x
162+
collectmaybe(x::Base.AbstractSet) = collect(x)

src/varinfo.jl

Lines changed: 211 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,22 @@ end
584584

585585
# Functions defined only for UntypedVarInfo
586586
"""
587-
keys(vi::UntypedVarInfo)
587+
keys(vi::AbstractVarInfo)
588588
589589
Return an iterator over all `vns` in `vi`.
590590
"""
591-
keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
591+
Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
592+
593+
@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
594+
expr = Expr(:call)
595+
push!(expr.args, :vcat)
596+
597+
for n in names
598+
push!(expr.args, :(vi.metadata.$n.vns))
599+
end
600+
601+
return expr
602+
end
592603

593604
"""
594605
setgid!(vi::VarInfo, gid::Selector, vn::VarName)
@@ -1165,19 +1176,39 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
11651176
end
11661177
end
11671178

1168-
setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
1169-
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
1170-
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1171-
end
1179+
# TODO: Maybe rename or something?
1180+
"""
1181+
_apply!(kernel!, vi::AbstractVarInfo, values, keys)
1182+
1183+
Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`.
1184+
"""
1185+
function _apply!(kernel!, vi::AbstractVarInfo, values, keys)
1186+
keys_strings = map(string, collectmaybe(keys))
1187+
num_indices_seen = 0
11721188

1173-
function _setval!(vi::AbstractVarInfo, values, keys)
11741189
for vn in Base.keys(vi)
1175-
_setval_kernel!(vi, vn, values, keys)
1190+
indices_found = kernel!(vi, vn, values, keys_strings)
1191+
if indices_found !== nothing
1192+
num_indices_seen += length(indices_found)
1193+
end
11761194
end
1195+
1196+
if length(keys) > num_indices_seen
1197+
# Some keys have not been seen, i.e. attempted to set variables which
1198+
# we were not able to locate in `vi`.
1199+
# Find the ones we missed so we can warn the user.
1200+
unused_keys = _find_missing_keys(vi, keys_strings)
1201+
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
1202+
end
1203+
11771204
return vi
11781205
end
1179-
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
1180-
@generated function _typed_setval!(
1206+
1207+
_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!(
1208+
kernel!, vi, vi.metadata, values, collectmaybe(keys))
1209+
1210+
@generated function _typed_apply!(
1211+
kernel!,
11811212
vi::TypedVarInfo,
11821213
metadata::NamedTuple{names},
11831214
values,
@@ -1186,30 +1217,189 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
11861217
updates = map(names) do n
11871218
quote
11881219
for vn in metadata.$n.vns
1189-
_setval_kernel!(vi, vn, values, keys)
1220+
indices_found = kernel!(vi, vn, values, keys_strings)
1221+
if indices_found !== nothing
1222+
num_indices_seen += length(indices_found)
1223+
end
11901224
end
11911225
end
11921226
end
1193-
1227+
11941228
return quote
1229+
keys_strings = map(string, keys)
1230+
num_indices_seen = 0
1231+
11951232
$(updates...)
1233+
1234+
if length(keys) > num_indices_seen
1235+
# Some keys have not been seen, i.e. attempted to set variables which
1236+
# we were not able to locate in `vi`.
1237+
# Find the ones we missed so we can warn the user.
1238+
unused_keys = _find_missing_keys(vi, keys_strings)
1239+
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
1240+
end
1241+
11961242
return vi
11971243
end
11981244
end
11991245

1246+
function _find_missing_keys(vi::AbstractVarInfo, keys)
1247+
string_vns = map(string, collectmaybe(Base.keys(vi)))
1248+
# If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
1249+
missing_keys = filter(keys) do key
1250+
!any(Base.Fix2(subsumes_string, key), string_vns)
1251+
end
1252+
1253+
return missing_keys
1254+
end
1255+
1256+
"""
1257+
setval!(vi::AbstractVarInfo, x)
1258+
setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
1259+
1260+
Set the values in `vi` to the provided values and leave those which are not present in
1261+
`x` or `chains` unchanged.
1262+
1263+
## Notes
1264+
This is rather limited for two reasons:
1265+
1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood,
1266+
and therefore suffers from the same limitations as [`subsumes_string`](@ref).
1267+
2. It will set every `vn` present in `keys`. It will NOT however
1268+
set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`,
1269+
representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0],))` will
1270+
be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0],))`.
1271+
1272+
## Example
1273+
```jldoctest
1274+
julia> using DynamicPPL, Distributions, StableRNGs
1275+
1276+
julia> @model function demo(x)
1277+
m ~ Normal()
1278+
for i in eachindex(x)
1279+
x[i] ~ Normal(m, 1)
1280+
end
1281+
end;
1282+
1283+
julia> rng = StableRNG(42);
1284+
1285+
julia> m = demo([missing]);
1286+
1287+
julia> var_info = DynamicPPL.VarInfo(rng, m);
1288+
1289+
julia> var_info[@varname(m)]
1290+
-0.6702516921145671
1291+
1292+
julia> var_info[@varname(x[1])]
1293+
-0.22312984965118443
1294+
1295+
julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and keep `x[1]`
1296+
1297+
julia> var_info[@varname(m)] # [✓] changed
1298+
100.0
1299+
1300+
julia> var_info[@varname(x[1])] # [✓] unchanged
1301+
-0.22312984965118443
1302+
1303+
julia> m(rng, var_info); # rerun model
1304+
1305+
julia> var_info[@varname(m)] # [✓] unchanged
1306+
100.0
1307+
1308+
julia> var_info[@varname(x[1])] # [✓] unchanged
1309+
-0.22312984965118443
1310+
```
1311+
"""
1312+
setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x))
1313+
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
1314+
return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1315+
end
1316+
12001317
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1201-
string_vn = string(vn)
1202-
string_vn_indexing = string_vn * "["
1203-
indices = findall(keys) do x
1204-
string_x = string(x)
1205-
return string_x == string_vn || startswith(string_x, string_vn_indexing)
1318+
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
1319+
if !isempty(indices)
1320+
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
1321+
val = reduce(vcat, values[sorted_indices])
1322+
setval!(vi, val, vn)
1323+
settrans!(vi, false, vn)
12061324
end
1325+
1326+
return indices
1327+
end
1328+
1329+
"""
1330+
setval_and_resample!(vi::AbstractVarInfo, x)
1331+
setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx)
1332+
1333+
Set the values in `vi` to the provided values and mark those which are not present
1334+
in `x` or `chains` to *be* resampled.
1335+
1336+
Note that this does *not* resample the values not provided! It will call `set_flag!(vi, vn, "del")`
1337+
for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these
1338+
variables will be resampled.
1339+
1340+
## Note
1341+
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
1342+
1343+
## Example
1344+
```jldoctest
1345+
julia> using DynamicPPL, Distributions, StableRNGs
1346+
1347+
julia> @model function demo(x)
1348+
m ~ Normal()
1349+
for i in eachindex(x)
1350+
x[i] ~ Normal(m, 1)
1351+
end
1352+
end;
1353+
1354+
julia> rng = StableRNG(42);
1355+
1356+
julia> m = demo([missing]);
1357+
1358+
julia> var_info = DynamicPPL.VarInfo(rng, m);
1359+
1360+
julia> var_info[@varname(m)]
1361+
-0.6702516921145671
1362+
1363+
julia> var_info[@varname(x[1])]
1364+
-0.22312984965118443
1365+
1366+
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
1367+
1368+
julia> var_info[@varname(m)] # [✓] changed
1369+
100.0
1370+
1371+
julia> var_info[@varname(x[1])] # [✓] unchanged
1372+
-0.22312984965118443
1373+
1374+
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
1375+
1376+
julia> var_info[@varname(m)] # [✓] unchanged
1377+
100.0
1378+
1379+
julia> var_info[@varname(x[1])] # [✓] changed
1380+
101.37363069798343
1381+
```
1382+
1383+
## See also
1384+
- [`setval!`](@ref)
1385+
"""
1386+
setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x))
1387+
function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
1388+
return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1389+
end
1390+
1391+
function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1392+
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
12071393
if !isempty(indices)
1208-
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
1209-
val = mapreduce(vcat, sorted_indices) do i
1210-
values[i]
1211-
end
1394+
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
1395+
val = reduce(vcat, values[sorted_indices])
12121396
setval!(vi, val, vn)
12131397
settrans!(vi, false, vn)
1398+
else
1399+
# Ensures that we'll resample the variable corresponding to `vn` if we run
1400+
# the model on `vi` again.
1401+
set_flag!(vi, vn, "del")
12141402
end
1403+
1404+
return indices
12151405
end

src/varname.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
"""
2+
subsumes_string(u::String, v::String[, u_indexing])
3+
4+
Check whether stringified variable name `v` describes a sub-range of stringified variable `u`.
5+
6+
This is a very restricted version of `subsumes(u::VarName, v::VarName)` only really supporting:
7+
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
8+
9+
## Note
10+
- To get the same matching capabilities as `AbstractPPL.subsumes(u::VarName, v::VarName)`
11+
for strings, one can always do `eval(varname(Meta.parse(u)))` to get `VarName` of `u`,
12+
and similarly for `v`. But this is slow.
13+
"""
14+
function subsumes_string(u::String, v::String, u_indexing=u * "[")
15+
return u == v || startswith(v, u_indexing)
16+
end
17+
118
"""
219
inargnames(varname::VarName, model::Model)
320

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1212
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
15+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1718
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -26,6 +27,7 @@ Documenter = "0.26.1"
2627
ForwardDiff = "0.10.12"
2728
MCMCChains = "4.0.4"
2829
MacroTools = "0.5.5"
30+
StableRNGs = "1"
2931
Tracker = "0.2.11"
3032
Zygote = "0.5.4, 0.6"
3133
julia = "1.3"

0 commit comments

Comments
 (0)