Skip to content

Commit 7f0ff38

Browse files
committed
renamed extract_realizations to values_as_in_model to be a bit
more descriptive (and similarly for the corresponding context)
1 parent add411a commit 7f0ff38

File tree

4 files changed

+33
-38
lines changed

4 files changed

+33
-38
lines changed

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl
143143
extract_priors
144144
```
145145

146-
Safe extraction of realizations from a given [`AbstractVarInfo`](@ref) can be done using [`extract_realizations`](@ref).
146+
Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are seen in the model can be done using [`values_as_in_model`](@ref).
147147

148148
```@docs
149-
extract_realizations
149+
values_as_in_model
150150
```
151151

152152
```@docs

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ export AbstractVarInfo,
9393
getargnames,
9494
generated_quantities,
9595
extract_priors,
96-
extract_realizations,
96+
values_as_in_model,
9797
# Samplers
9898
Sampler,
9999
SampleFromPrior,

src/contexts.jl

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -666,53 +666,54 @@ function fixed(context::FixedContext)
666666
end
667667

668668
"""
669-
RealizationExtractorContext
669+
ValuesAsInModelContext
670670
671-
A context that is used to extract realizations from a model.
671+
A context that is used by [`values_as_in_model`](@ref) to obtain values
672+
of the model parameters as they are in the model.
672673
673674
This is particularly useful when working in unconstrained space, but one
674675
wants to extract the realization of a model in a constrained space.
675676
676677
# Fields
677678
$(TYPEDFIELDS)
678679
"""
679-
struct RealizationExtractorContext{T,C<:AbstractContext} <: AbstractContext
680+
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
680681
"values that are extracted from the model"
681682
values::T
682683
"child context"
683684
context::C
684685
end
685686

686-
RealizationExtractorContext(values) = RealizationExtractorContext(values, DefaultContext())
687-
function RealizationExtractorContext(context::AbstractContext)
688-
return RealizationExtractorContext(OrderedDict(), context)
687+
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
688+
function ValuesAsInModelContext(context::AbstractContext)
689+
return ValuesAsInModelContext(OrderedDict(), context)
689690
end
690691

691-
NodeTrait(::RealizationExtractorContext) = IsParent()
692-
childcontext(context::RealizationExtractorContext) = context.context
693-
function setchildcontext(context::RealizationExtractorContext, child)
694-
return RealizationExtractorContext(context.values, child)
692+
NodeTrait(::ValuesAsInModelContext) = IsParent()
693+
childcontext(context::ValuesAsInModelContext) = context.context
694+
function setchildcontext(context::ValuesAsInModelContext, child)
695+
return ValuesAsInModelContext(context.values, child)
695696
end
696697

697-
function Base.push!(context::RealizationExtractorContext, vn::VarName, value)
698+
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
698699
return setindex!(context.values, copy(value), vn)
699700
end
700701

701-
function broadcast_push!(context::RealizationExtractorContext, vns, values)
702+
function broadcast_push!(context::ValuesAsInModelContext, vns, values)
702703
return push!.((context,), vns, values)
703704
end
704705

705706
# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
706707
function broadcast_push!(
707-
context::RealizationExtractorContext, vns::AbstractVector, values::AbstractMatrix
708+
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
708709
)
709710
for (vn, col) in zip(vns, eachcol(values))
710711
push!(context, vn, col)
711712
end
712713
end
713714

714715
# `tilde_asssume`
715-
function tilde_assume(context::RealizationExtractorContext, right, vn, vi)
716+
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
716717
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
717718
# Save the value.
718719
push!(context, vn, value)
@@ -721,7 +722,7 @@ function tilde_assume(context::RealizationExtractorContext, right, vn, vi)
721722
return value, logp, vi
722723
end
723724
function tilde_assume(
724-
rng::Random.AbstractRNG, context::RealizationExtractorContext, sampler, right, vn, vi
725+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
725726
)
726727
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
727728
# Save the value.
@@ -731,7 +732,7 @@ function tilde_assume(
731732
end
732733

733734
# `dot_tilde_assume`
734-
function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, vi)
735+
function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
735736
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)
736737

737738
# Save the value.
@@ -741,13 +742,7 @@ function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn,
741742
return value, logp, vi
742743
end
743744
function dot_tilde_assume(
744-
rng::Random.AbstractRNG,
745-
context::RealizationExtractorContext,
746-
sampler,
747-
right,
748-
left,
749-
vn,
750-
vi,
745+
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi
751746
)
752747
value, logp, vi = dot_tilde_assume(
753748
rng, childcontext(context), sampler, right, left, vn, vi
@@ -760,10 +755,10 @@ function dot_tilde_assume(
760755
end
761756

762757
"""
763-
extract_realizations(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
764-
extract_realizations(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
758+
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
759+
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
765760
766-
Extract realizations from the `model` for a given `varinfo` through a evaluation of the model.
761+
Get the values of `varinfo` as they would be seen in the model.
767762
768763
If no `varinfo` is provided, then this is effectively the same as
769764
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
@@ -826,27 +821,27 @@ julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
826821
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
827822
false
828823
829-
julia> # Approach 2: Extract realizations using `extract_realizations`.
830-
# (✓) `extract_realizations` will re-run the model and extract
824+
julia> # Approach 2: Extract realizations using `values_as_in_model`.
825+
# (✓) `values_as_in_model` will re-run the model and extract
831826
# the correct realization of `y` given the new values of `x`.
832-
lb ≤ extract_realizations(model, varinfo_linked)[@varname(y)] ≤ ub
827+
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
833828
true
834829
```
835830
"""
836-
function extract_realizations(
831+
function values_as_in_model(
837832
model::Model,
838833
varinfo::AbstractVarInfo=VarInfo(),
839834
context::AbstractContext=DefaultContext(),
840835
)
841-
context = RealizationExtractorContext(context)
836+
context = ValuesAsInModelContext(context)
842837
evaluate!!(model, varinfo, context)
843838
return context.values
844839
end
845-
function extract_realizations(
840+
function values_as_in_model(
846841
rng::Random.AbstractRNG,
847842
model::Model,
848843
varinfo::AbstractVarInfo=VarInfo(),
849844
context::AbstractContext=DefaultContext(),
850845
)
851-
return extract_realizations(model, varinfo, SamplingContext(rng, context))
846+
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
852847
end

test/model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,13 +376,13 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
376376
end
377377
end
378378

379-
@testset "extract_realizations" begin
379+
@testset "values_as_in_model" begin
380380
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
381381
vns = DynamicPPL.TestUtils.varnames(model)
382382
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
383383
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
384384
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
385-
realizations = extract_realizations(model, varinfo)
385+
realizations = values_as_in_model(model, varinfo)
386386
# Ensure that all variables are found.
387387
vns_found = collect(keys(realizations))
388388
@test vns vns_found == vns vns_found

0 commit comments

Comments
 (0)