Skip to content

Commit af65bb0

Browse files
committed
Merge remote-tracking branch 'origin/master' into release-0.35
2 parents a34fb04 + 727da63 commit af65bb0

File tree

5 files changed

+46
-107
lines changed

5 files changed

+46
-107
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1515
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1616
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1717
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
18+
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
19+
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
20+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1821
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1922
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2023
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -55,6 +58,9 @@ Compat = "4"
5558
ConstructionBase = "1.5.4"
5659
Distributions = "0.25"
5760
DocStringExtensions = "0.9"
61+
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
62+
# for why KernelAbstractions is pinned like this.
63+
KernelAbstractions = "< 0.9.32"
5864
EnzymeCore = "0.6 - 0.8"
5965
ForwardDiff = "0.10"
6066
JET = "0.9"

src/values_as_in_model.jl

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space.
1919
# Fields
2020
$(TYPEDFIELDS)
2121
"""
22-
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
22+
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
2323
"values that are extracted from the model"
24-
values::T
24+
values::OrderedDict
2525
"whether to extract variables on the LHS of :="
2626
include_colon_eq::Bool
2727
"child context"
@@ -114,34 +114,32 @@ function dot_tilde_assume(
114114
end
115115

116116
"""
117-
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
118-
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
117+
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
119118
120119
Get the values of `varinfo` as they would be seen in the model.
121120
122-
If no `varinfo` is provided, then this is effectively the same as
123-
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
121+
More specifically, this method attempts to extract the realization _as seen in
122+
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
123+
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
124+
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
125+
working in unconstrained space.
124126
125-
More specifically, this method attempts to extract the realization _as seen in the model_.
126-
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
127-
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
128-
space.
129-
130-
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
131-
of additional model evaluations.
127+
Hence this method is a "safe" way of obtaining realizations in constrained
128+
space at the cost of additional model evaluations.
132129
133130
# Arguments
134131
- `model::Model`: model to extract realizations from.
135132
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
136133
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
137-
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
138-
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
134+
- `context::AbstractContext`: base context to use for the extraction. Defaults
135+
to `DynamicPPL.DefaultContext()`.
139136
140137
# Examples
141138
142139
## When `VarInfo` fails
143140
144-
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
141+
The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
142+
and constrained variables.
145143
146144
```jldoctest
147145
julia> using Distributions, StableRNGs
@@ -191,19 +189,10 @@ true
191189
function values_as_in_model(
192190
model::Model,
193191
include_colon_eq::Bool,
194-
varinfo::AbstractVarInfo=VarInfo(),
192+
varinfo::AbstractVarInfo,
195193
context::AbstractContext=DefaultContext(),
196194
)
197195
context = ValuesAsInModelContext(include_colon_eq, context)
198196
evaluate!!(model, varinfo, context)
199197
return context.values
200198
end
201-
function values_as_in_model(
202-
rng::Random.AbstractRNG,
203-
model::Model,
204-
include_colon_eq::Bool,
205-
varinfo::AbstractVarInfo=VarInfo(),
206-
context::AbstractContext=DefaultContext(),
207-
)
208-
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
209-
end

src/varinfo.jl

Lines changed: 14 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
521521
offset = 0
522522

523523
for (idx, vn) in enumerate(vns_both)
524-
# `idcs`
525524
idcs[vn] = idx
526-
# `vns`
527525
push!(vns, vn)
528-
if vn in vns_left && vn in vns_right
529-
# `vals`: only valid if they're the length.
530-
vals_left = getindex_internal(metadata_left, vn)
531-
vals_right = getindex_internal(metadata_right, vn)
532-
@assert length(vals_left) == length(vals_right)
533-
append!(vals, vals_right)
534-
# `ranges`
535-
r = (offset + 1):(offset + length(vals_left))
536-
push!(ranges, r)
537-
offset = r[end]
538-
# `dists`: only valid if they're the same.
539-
dist_right = getdist(metadata_right, vn)
540-
# Give precedence to `metadata_right`.
541-
push!(dists, dist_right)
542-
gid = metadata_right.gids[getidx(metadata_right, vn)]
543-
push!(gids, gid)
544-
# `orders`: giving precedence to `metadata_right`
545-
push!(orders, getorder(metadata_right, vn))
546-
# `flags`
547-
for k in keys(flags)
548-
# Using `metadata_right`; should we?
549-
push!(flags[k], is_flagged(metadata_right, vn, k))
550-
end
551-
elseif vn in vns_left
552-
# Just extract the metadata from `metadata_left`.
553-
# `vals`
554-
vals_left = getindex_internal(metadata_left, vn)
555-
append!(vals, vals_left)
556-
# `ranges`
557-
r = (offset + 1):(offset + length(vals_left))
558-
push!(ranges, r)
559-
offset = r[end]
560-
# `dists`
561-
dist_left = getdist(metadata_left, vn)
562-
push!(dists, dist_left)
563-
gid = metadata_left.gids[getidx(metadata_left, vn)]
564-
push!(gids, gid)
565-
# `orders`
566-
push!(orders, getorder(metadata_left, vn))
567-
# `flags`
568-
for k in keys(flags)
569-
push!(flags[k], is_flagged(metadata_left, vn, k))
570-
end
571-
else
572-
# Just extract the metadata from `metadata_right`.
573-
# `vals`
574-
vals_right = getindex_internal(metadata_right, vn)
575-
append!(vals, vals_right)
576-
# `ranges`
577-
r = (offset + 1):(offset + length(vals_right))
578-
push!(ranges, r)
579-
offset = r[end]
580-
# `dists`
581-
dist_right = getdist(metadata_right, vn)
582-
push!(dists, dist_right)
583-
gid = metadata_right.gids[getidx(metadata_right, vn)]
584-
push!(gids, gid)
585-
# `orders`
586-
push!(orders, getorder(metadata_right, vn))
587-
# `flags`
588-
for k in keys(flags)
589-
push!(flags[k], is_flagged(metadata_right, vn, k))
590-
end
526+
metadata_for_vn = vn in vns_right ? metadata_right : metadata_left
527+
528+
val = getindex_internal(metadata_for_vn, vn)
529+
append!(vals, val)
530+
r = (offset + 1):(offset + length(val))
531+
push!(ranges, r)
532+
offset = r[end]
533+
dist = getdist(metadata_for_vn, vn)
534+
push!(dists, dist)
535+
gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
536+
push!(gids, gid)
537+
push!(orders, getorder(metadata_for_vn, vn))
538+
for k in keys(flags)
539+
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
591540
end
592541
end
593542

test/model.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
429429
end
430430
end
431431
end
432-
433-
@testset "check that sampling obeys rng if passed" begin
434-
@model function f()
435-
x ~ Normal(0)
436-
return y ~ Normal(x)
437-
end
438-
model = f()
439-
# Call values_as_in_model with the rng
440-
values = values_as_in_model(Random.Xoshiro(43), model, false)
441-
# Check that they match the values that would be used if vi was seeded
442-
# with that seed instead
443-
expected_vi = VarInfo(Random.Xoshiro(43), model)
444-
for vn in keys(values)
445-
@test values[vn] == expected_vi[vn]
446-
end
447-
end
448432
end
449433

450434
@testset "Erroneous model call" begin

test/varinfo.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,17 @@ end
869869
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
870870
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
871871
end
872+
873+
# The below used to error, testing to avoid regression.
874+
@testset "merge different dimensions" begin
875+
vn = @varname(x)
876+
vi_single = VarInfo()
877+
vi_single = push!!(vi_single, vn, 1.0, Normal())
878+
vi_double = VarInfo()
879+
vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
880+
@test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
881+
@test merge(vi_double, vi_single)[vn] == 1.0
882+
end
872883
end
873884

874885
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)