Skip to content

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Oct 20, 2025

Closes #1077
Closes TuringLang/Turing.jl#2659

This PR implements the necessary methods to get ProductNamedTupleDistribution working in Turing. I've tested with this PR + TuringLang/Turing.jl#2689 and this model works with both MH (i.e. unlinked) and NUTS (i.e. linked). It contains the absolute most nightmare scenario where there is a nested PNTDist plus LKJCholesky inside a PNTDist.

using Turing

@model function f()
    d = product_distribution((s = Normal(), m = InverseGamma(2, 3)))
    x ~ d
    a ~ product_distribution((s = LKJCholesky(3, 0.5), x = d))
    y ~ Normal(a.x.s, a.x.m)
    return x, a, y
end

sample(f(), MH(), 1000)
sample(f(), NUTS(), 1000)

This is all type stable too:

d = product_distribution((s = Normal(), m = InverseGamma(2, 3)))
x = rand(d); y = DynamicPPL.tovec(x)
# @inferred DynamicPPL.to_vec_transform(d)(x) # Doesn't work
@inferred DynamicPPL.tovec(x)
@inferred DynamicPPL.from_vec_transform(d)(y)

d2 = product_distribution((s = LKJCholesky(3, 0.5), d = d))
x = rand(d2); y = DynamicPPL.tovec(x)
# @inferred DynamicPPL.to_vec_transform(d2)(x) # Doesn't work
@inferred DynamicPPL.tovec(x)
@inferred DynamicPPL.from_vec_transform(d2)(y)

This PR also adds tests for all the transform methods for a number of typical distributions.


future work

I think these should be a separate PR, but:

  • I feel kind of uncomfortable with the whole from_vec_transform stuff. I spent way too long trying to figure out the meaning of things. It's IMO not clear what the inputs and outputs are expected to be, and some examples would go a long way.
  • It's also unclear which methods actually need to be overloaded. For example, I had to implement from_vec_transform but I didn't have to implement to_vec_transform because the Metadata/VNV code just skips over it and goes straight to tovec. I think Metadata/VNV code should use to_vec_transform and the default definition of to_vec_transform should internally use _tovec.
  • It's also confusing what methods exist. The codebase sometimes uses from_vec_transform(dist::Distribution) but in particular when handling VarNamedVector it also uses from_vec_transform(x) where x is a sample drawn from dist. I think the latter should be renamed.
  • In fact, I think I got this diagram in the docs wrong: https://turinglang.org/docs/developers/transforms/dynamicppl/#a-deeper-dive-into-dynamicppls-internal-machinery because I tried to look at it to refresh my memory, and I couldn't make any sense of it.

@penelopeysm penelopeysm changed the base branch from main to breaking October 20, 2025 17:59
Copy link
Contributor

DynamicPPL.jl documentation for PR #1079 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1079/

Copy link

codecov bot commented Oct 20, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.06%. Comparing base (4addb5f) to head (854a6cb).

Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1079      +/-   ##
============================================
- Coverage     81.13%   81.06%   -0.08%     
============================================
  Files            40       40              
  Lines          3722     3749      +27     
============================================
+ Hits           3020     3039      +19     
- Misses          702      710       +8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@penelopeysm
Copy link
Member Author

penelopeysm commented Oct 20, 2025

Here are some docs, to be used in a future PR:

"""
DynamicPPL transforms
---------------------

In general, there are four ways that parameters are internally represented:

1. **Unlinked**: The default, unchanged, representation of the parameter, as returned by
   `rand(dist)`.

2. **Unlinked vectorised**: A flattened vector form of the original representation.

3. **Linked**: A transformed version of the original representation. This is accomplished
   by applying a bijective transformation (a 'link function') to the unlinked parameter.
   This results in a parameter that exists in unconstrained Euclidean space.

4. **Linked vectorised**: A flattened vector form of the linked representation.

## Example 1

Consider the variable `x ~ LogNormal()`.

1. If we draw a sample from this distribution, it is necessarily positive, so the unlinked
   representation is a positive Float64 value (e.g. `2.0`).
2. The unlinked vectorised representation is a one-element vector containing that positive
   Float64 value (e.g. `[2.0]`).
3. For `LogNormal()`, the appropriate link function is the logarithm, which maps positive
   values to the entire real line. Thus, the linked representation is `log(2.0) ≈ 0.693`.
4. The linked vectorised representation is a one-element vector containing that real value
    (e.g. `[0.693]`).

## Example 2

Consider now `y ~ LKJCholesky(2, 0.5)`.

1. A sample from this distribution is a `LinearAlgebra.Cholesky` object which wraps a
   2x2 lower-triangular matrix. The top-left element is always 1.0, and the bottom row
   must satisfy y.L[2, 1]^2 + y.L[2, 2]^2 = 1.

2. The unlinked vectorised representation is a vectorised form of this lower-triangular
   matrix. In principle, y.L[1, 1] (which is always one) and y.L[1, 2] (which is always
   zero) could be dropped to save space, but DynamicPPL's implementation retains it. This
   Thus, the unlinked vectorised representation is a 4-element vector `[1.0, 0.0, y.L[2, 1],
   y.L[2, 2]]`.

3. The link function for `LKJCholesky` maps the Cholesky sample to a vector of unconstrained
   real values. This is accomplished using Bijectors.VecCholeskyBijector. The implementation
   of this is not important here, but it is worth pointing out that there is only one free
   parameter in this 2x2 Cholesky factor (since the top row is fixed, and the two elements in
   the lower row are interdependent). Thus, the linked representation is a vector with one
   value.

4. Since the linked representation is already a vector, the linked vectorised representation
   is identical to it.

## Transform functions

This file defines functions which map between these different representations. Specifically:
                                                                          

                             linked vectorised                                     
                                                                                   
                                    ▲ │                                            
                                    │ │                                            
          from_linked_vec_transform │ │ to_linked_vec_transform                    
                                    │ │                                            
                                    │ ▼                                            
                                                  link_transform                   
                                 unlinked     ──────────────────────►      linked  
                                              ◄──────────────────────              
                                    ▲ │          invlink_transform                 
                                    │ │                                            
                 from_vec_transform │ │ to_vec_transform                           
                                    │ │                                            
                                    │ ▼                                            
                                                                                   
                             unlinked vectorised                                   

Note that all of these `..._transform` functions do not actually perform the transformation.
Instead, when called with a distribution `dist` as the only argument, they return a new
function which _then_ performs the transformation.

For example, if `dist = LogNormal()` and `x = 2.0`, then _in principle_ the following should
hold:

- `to_vec_transform(dist)(x) == [2.0]`
- `link_transform(dist)(x) == log(2.0)`
- `to_linked_vec_transform(dist)(x) == [log(2.0)]

## Implementing a new distribution

To implement the necessary transformations for a new distribution `dist`, you need to do the
following:

1. `link_transform` and `invlink_transform` should be implemented via Bijectors.jl; there is
   no need to add custom functionality for this in DynamicPPL. In there you should define `b
   = bijector(dist)`, along with `with_logabsdet_jacobian(b, x)` where `x` is in the support
   of `dist`. You also need to define `inverse(b)` and `with_logabsdet_jacobian(inverse(b),
   y)`.

2. Implement `from_vec_transform(dist)`. It turns out that although `to_vec_transform`
   exists as a function which can be overloaded, DynamicPPL does not actually ever call it,
   so you do not need to implement it.

3. Implement `from_linked_vec_transform(dist)`. There are three general cases here, and your
   distribution will most likely fall into one of them:

   - If `dist` is univariate, then the linked representation is a scalar, and the linked
     vectorised representation is a one-element vector. In this case, you can implement
     `from_linked_vec_transform(dist)` as the composition of `invlink_transform(dist)` and
     `only`.

   - If the linked representation is already a vector (which is always the case for
     multivariate distributions), then `from_linked_vec_transform(dist) =
     invlink_transform(dist)`. This follows because the linked vectorised representation
     is already the same as the linked representation. For example, this is the case for
     `LKJCholesky`.

   - If the linked representation is something else (e.g. a matrix) then you will need to
     implement a custom `from_linked_vec_transform(dist)`, which first reshapes the linked
     vectorised representation back into the linked representation, and then applies
     `invlink_transform(dist)`.

## Jacobians

The above is slightly misleading, because the actual requirement is not to implement the
transformation functions `trf(x)` for the transformation returned by `from_vec_transform`
and friends. In fact, what DynamicPPL really needs is an implementation of
`Bijectors.with_logabsdet_jacobian(trf, x)` for the transformation `trf`. This requirement
means that it is often more convenient to implement `trf` as a callable struct, since this
allows for convenient dispatch on its type.

For simple vectorisation transforms like the outputs of `from_vec_transform`, the Jacobian
term is zero, so often `Bijectors.with_logabsdet_jacobian(trf, x)` just returns `trf(x),
zero(LogProbType)`. For the results of `from_linked_vec_transform`, the Jacobian term will
be the same as that of the link transformation. In general, you can often rely on this
to be automatically computed if you define `from_linked_vec_transform` as a composition
of other transformations for which `Bijectors.with_logabsdet_jacobian` is already defined.
"""

@penelopeysm penelopeysm marked this pull request as ready for review October 21, 2025 12:07
@penelopeysm penelopeysm requested a review from mhauru October 21, 2025 12:51
@penelopeysm
Copy link
Member Author

penelopeysm commented Oct 21, 2025

CI failures are unrelated, I'll fix investigate them on main...

Edit: #1081

Copy link
Member

@mhauru mhauru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work. I agree that all the from/to_ transform stuff gets confusing.

@penelopeysm penelopeysm requested a review from mhauru October 21, 2025 16:29
@penelopeysm
Copy link
Member Author

Bonus: the file is editable in nvim again

Comment on lines +369 to +373
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
dists::T
# The `i`-th input range corresponds to the segment of the input vector
# that belongs to the `i`-th distribution.
input_ranges::Vector{UnitRange}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a subpar data structure because it assumes that dists and input_ranges have the same length (this is enforced in the inner constructor). I think in an ideal world we would combine dists and input_ranges into a single NamedTuple ... but the issue with that is that to make it type stable I think we'd have to make the constructor also a generated function.

Base automatically changed from breaking to main October 21, 2025 17:06
Copy link
Member

@mhauru mhauru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice.

@penelopeysm penelopeysm merged commit ca500e2 into main Oct 21, 2025
3 of 17 checks passed
@penelopeysm penelopeysm deleted the py/pntdist branch October 21, 2025 17:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement to_vec_transform for ProductNamedTupleDistribution Support for Distributions.ProductNamedTupleDistribution

2 participants