Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ x ~ product_distribution(Normal.(y))
x ~ MvNormal(fill(0.0, 2), I)
```

This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as
This is often more performant as well.

The new implementation of `x .~ ...` is just a short-hand for `x ~ filldist(...)`, which means that `x` will be seen as a single multivariate variable. In most cases this does not change anything for the user, with the one notable exception being `pointwise_loglikelihoods`, which previously treated `.~` assignments as assigning multiple univariate variables. If you _do_ want a variable to be seen as an array of univariate variables rather than a single multivariate variable, you can always expand into a loop, such as
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to clearly state that each .~ and ~ defines a single random variable. For more flexible condition'ing, I think you could try to support model | x = [missing, 1., 2., missing], which would allow users to condition on a subset of elements in x but still treat x as a single random variable. Again, please document this clearly in breaking changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The new implementation of `x .~ ...` is just a short-hand for `x ~ filldist(...)`, which means that `x` will be seen as a single multivariate variable. In most cases this does not change anything for the user, with the one notable exception being `pointwise_loglikelihoods`, which previously treated `.~` assignments as assigning multiple univariate variables. If you _do_ want a variable to be seen as an array of univariate variables rather than a single multivariate variable, you can always expand into a loop, such as
The new implementation of `x .~ ...` is just a short-hand for `x ~ product_distribution(...)`, which means that `x` will be seen as a single multivariate variable. In most cases this does not change anything for the user, with the one notable exception being `pointwise_loglikelihoods`, which previously treated `.~` assignments as assigning multiple univariate variables. If you _do_ want a variable to be seen as an array of univariate variables rather than a single multivariate variable, you can always expand into a loop, such as


```julia
dists = Normal.(y)
Expand All @@ -54,7 +56,7 @@ for i in 1:length(dists)
end
```

Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example,
Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must always be replaced with a loop. For example,

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
Expand All @@ -70,8 +72,6 @@ for i in 1:3, j in 1:4
end
```

This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side.

### Remove indexing by samplers

This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using AbstractPPL
using Bijectors
using Compat
using Distributions
using DistributionsAD: filldist
using OrderedCollections: OrderedCollections, OrderedDict

using AbstractMCMC: AbstractMCMC
Expand Down
7 changes: 2 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,10 @@ end
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
@gensym dist left_axes idx
@gensym dist
return quote
$dist = DynamicPPL.check_dot_tilde_rhs($right)
$left_axes = axes($left)
for $idx in Iterators.product($left_axes...)
$left[$idx...] ~ $dist
end
$left ~ DynamicPPL.filldist($dist, Base.size($left)...)
end
end

Expand Down
9 changes: 4 additions & 5 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ y .~ Normal(μ, σ)
y ~ MvNormal(fill(μ, n), σ^2 * I)
```

In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
while in (3) `y` will be treated as a _single_ n-dimensional observation.
In (1) `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
while in (2) and (3) `y` will be treated as a _single_ n-dimensional observation.

This is important to keep in mind, in particular if the computation is used
for downstream computations.
Expand Down Expand Up @@ -216,8 +216,7 @@ OrderedDict{VarName, Matrix{Float64}} with 6 entries:
```

## Broadcasting
Note that `x .~ Dist()` will treat `x` as a collection of
_independent_ observations rather than as a single observation.
Note that `x .~ Dist()` will treat `x` as a single multivariate observation.

```jldoctest; setup = :(using Distributions)
julia> @model function demo(x)
Expand All @@ -226,7 +225,7 @@ julia> @model function demo(x)

julia> m = demo([1.0, ]);

julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])])
julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x)])
-1.4189385332046727

julia> m = demo([1.0; 1.0]);
Expand Down
12 changes: 6 additions & 6 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_observe)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

@model function demo_assume_index_observe(
Expand Down Expand Up @@ -293,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_observe_index)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

# Using vector of `length` 1 here so the posterior of `m` is the same
Expand Down Expand Up @@ -374,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

@model function demo_assume_observe_literal()
Expand Down Expand Up @@ -458,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

@model function _likelihood_multivariate_observe(s, m, x)
Expand Down Expand Up @@ -492,7 +492,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

@model function demo_dot_assume_observe_matrix_index(
Expand Down Expand Up @@ -521,7 +521,7 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)})
return [@varname(s[1]), @varname(s[2]), @varname(m)]
return [@varname(s), @varname(m)]
end

@model function demo_assume_matrix_observe_matrix_index(
Expand Down
9 changes: 2 additions & 7 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,7 @@ end
# Transform only one variable
all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns)
for vn in [
@varname(s),
@varname(m),
@varname(x),
@varname(y),
@varname(x[2]),
@varname(y[2])
@varname(s), @varname(m), @varname(x), @varname(y), @varname(x), @varname(y[2])
]
target_vns = filter(x -> subsumes(vn, x), all_vns)
other_vns = filter(x -> !subsumes(vn, x), all_vns)
Expand Down Expand Up @@ -874,7 +869,7 @@ end
varinfo2 = last(
DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext())
)
for vn in [@varname(x), @varname(y[1])]
for vn in [@varname(x), @varname(y)]
@test DynamicPPL.istrans(varinfo2, vn)
end
end
Expand Down
Loading