-
Notifications
You must be signed in to change notification settings - Fork 36
ProductNamedTupleDistribution compatibility #1079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a63d021
16aca2c
2268acd
3c1aeee
63a75e7
dd512c8
4feb182
e2f7d68
3f31d4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -354,7 +354,6 @@ Return the transformation from the vector representation of a realization of siz | |
original representation. | ||
""" | ||
from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz) | ||
# TODO(mhauru) Is the below used? If not, this function can be removed. | ||
from_vec_transform_for_size(::Tuple{<:Any}) = identity | ||
|
||
""" | ||
|
@@ -367,6 +366,60 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) | |
from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() | ||
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) | ||
|
||
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} | ||
dists::T | ||
# The `i`-th input range corresponds to the segment of the input vector | ||
# that belongs to the `i`-th distribution. | ||
input_ranges::Vector{UnitRange} | ||
Comment on lines
+369
to
+373
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a subpar data structure because it assumes that dists and input_ranges have the same length (this is enforced in the inner constructor). I think in an ideal world we would combine dists and input_ranges into a single NamedTuple ... but the issue with that is that to make it type stable I think we'd have to make the constructor also a generated function. |
||
function ProductNamedTupleUnvecTransform( | ||
d::Distributions.ProductNamedTupleDistribution{names} | ||
) where {names} | ||
offset = 1 | ||
input_ranges = UnitRange[] | ||
for name in names | ||
this_dist = d.dists[name] | ||
this_name_size = _input_length(from_vec_transform(this_dist)) | ||
push!(input_ranges, offset:(offset + this_name_size - 1)) | ||
offset += this_name_size | ||
end | ||
return new{names,typeof(d.dists)}(d.dists, input_ranges) | ||
end | ||
end | ||
penelopeysm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@generated function (trf::ProductNamedTupleUnvecTransform{names})( | ||
x::AbstractVector | ||
) where {names} | ||
expr = Expr(:tuple) | ||
for (i, name) in enumerate(names) | ||
push!( | ||
expr.args, | ||
:($name = from_vec_transform(trf.dists.$name)(x[trf.input_ranges[$i]])), | ||
) | ||
end | ||
return expr | ||
end | ||
penelopeysm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution) | ||
return ProductNamedTupleUnvecTransform(dist) | ||
end | ||
function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x) | ||
return f(x), zero(LogProbType) | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
penelopeysm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# This function returns the length of the vector that the function from_vec_transform | ||
# expects. This helps us determine which segment of a concatenated vector belongs to which | ||
# variable. | ||
_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1 | ||
_input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size) | ||
function _input_length(trfm::ProductNamedTupleUnvecTransform) | ||
return sum(_input_length ∘ from_vec_transform, values(trfm.dists)) | ||
end | ||
function _input_length( | ||
c::ComposedFunction{<:DynamicPPL.ToChol,<:DynamicPPL.ReshapeTransform} | ||
) | ||
return _input_length(c.inner) | ||
end | ||
|
||
""" | ||
from_vec_transform(f, size::Tuple) | ||
|
||
|
@@ -405,7 +458,9 @@ function from_linked_vec_transform(dist::UnivariateDistribution) | |
sz = Bijectors.output_size(f_combined, size(dist)) | ||
return UnwrapSingletonTransform(sz) ∘ f_combined | ||
end | ||
|
||
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution) | ||
return invlink_transform(dist) | ||
end | ||
# Specializations that circumvent the `from_vec_transform` machinery. | ||
function from_linked_vec_transform(dist::LKJCholesky) | ||
return inverse(Bijectors.VecCholeskyBijector(dist.uplo)) | ||
|
Uh oh!
There was an error while loading. Please reload this page.