Skip to content

Commit 938a69d

Browse files
authored
Restrict values_as_in_model API (#778)
1 parent e673b69 commit 938a69d

File tree

3 files changed

+16
-43
lines changed

3 files changed

+16
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.33.1"
3+
version = "0.34.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/values_as_in_model.jl

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space.
1919
# Fields
2020
$(TYPEDFIELDS)
2121
"""
22-
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
22+
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
2323
"values that are extracted from the model"
24-
values::T
24+
values::OrderedDict
2525
"whether to extract variables on the LHS of :="
2626
include_colon_eq::Bool
2727
"child context"
@@ -114,34 +114,32 @@ function dot_tilde_assume(
114114
end
115115

116116
"""
117-
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
118-
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
117+
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
119118
120119
Get the values of `varinfo` as they would be seen in the model.
121120
122-
If no `varinfo` is provided, then this is effectively the same as
123-
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
121+
More specifically, this method attempts to extract the realization _as seen in
122+
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
123+
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
124+
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
125+
working in unconstrained space.
124126
125-
More specifically, this method attempts to extract the realization _as seen in the model_.
126-
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
127-
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
128-
space.
129-
130-
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
131-
of additional model evaluations.
127+
Hence this method is a "safe" way of obtaining realizations in constrained
128+
space at the cost of additional model evaluations.
132129
133130
# Arguments
134131
- `model::Model`: model to extract realizations from.
135132
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
136133
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
137-
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
138-
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
134+
- `context::AbstractContext`: base context to use for the extraction. Defaults
135+
to `DynamicPPL.DefaultContext()`.
139136
140137
# Examples
141138
142139
## When `VarInfo` fails
143140
144-
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
141+
The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
142+
and constrained variables.
145143
146144
```jldoctest
147145
julia> using Distributions, StableRNGs
@@ -191,19 +189,10 @@ true
191189
function values_as_in_model(
192190
model::Model,
193191
include_colon_eq::Bool,
194-
varinfo::AbstractVarInfo=VarInfo(),
192+
varinfo::AbstractVarInfo,
195193
context::AbstractContext=DefaultContext(),
196194
)
197195
context = ValuesAsInModelContext(include_colon_eq, context)
198196
evaluate!!(model, varinfo, context)
199197
return context.values
200198
end
201-
function values_as_in_model(
202-
rng::Random.AbstractRNG,
203-
model::Model,
204-
include_colon_eq::Bool,
205-
varinfo::AbstractVarInfo=VarInfo(),
206-
context::AbstractContext=DefaultContext(),
207-
)
208-
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
209-
end

test/model.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
429429
end
430430
end
431431
end
432-
433-
@testset "check that sampling obeys rng if passed" begin
434-
@model function f()
435-
x ~ Normal(0)
436-
return y ~ Normal(x)
437-
end
438-
model = f()
439-
# Call values_as_in_model with the rng
440-
values = values_as_in_model(Random.Xoshiro(43), model, false)
441-
# Check that they match the values that would be used if vi was seeded
442-
# with that seed instead
443-
expected_vi = VarInfo(Random.Xoshiro(43), model)
444-
for vn in keys(values)
445-
@test values[vn] == expected_vi[vn]
446-
end
447-
end
448432
end
449433

450434
@testset "Erroneous model call" begin

0 commit comments

Comments
 (0)