For certain distributions, the random variable represented by the Distribution has support which is lower-dimensional than the return-type indicates; that is, the returned realizations are embedded in a higher dimensional space.
For example, LKJ is a distribution over correlation matrices. Correlation matrices are required to be positive-definite (PD) and have 1 along the diagonal. PD implies symmetry, which means that we only have (n choose 2) + n degrees of freedom, and 1 along the diagonal removes the additional n, leaving us with only (n choose 2) degrees of freedom. That is, the space of correlation matrices has dimension just (n choose 2), not n × n as might be indicated by the Matrix{Float64} returned from rand(::LKJ)!
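The counting argument above can be checked with a few lines of plain Julia. This is only an illustrative sketch; the variable names are made up here and nothing in it depends on Distributions.jl or Bijectors.jl.

```julia
# Degrees of freedom for an n × n correlation matrix, using n = 3 as in LKJ(3, 1).
n = 3

n_entries     = n * n               # entries in the returned Matrix{Float64}
dof_symmetric = binomial(n, 2) + n  # free entries of a symmetric matrix
dof_corr      = binomial(n, 2)      # free entries once the diagonal is fixed to 1

println((n_entries, dof_symmetric, dof_corr))  # (9, 6, 3)
```

For n = 3 this gives 9 stored entries against 3 free parameters.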
For SimpleVarInfo, this is trivial to support because SimpleVarInfo only contains the realizations themselves, no information related to the distributions, etc. Therefore, with something like TuringLang/Bijectors.jl#246, things just work:
julia> using DynamicPPL, Distributions, Bijectors
julia> # Switch the bijector used to the `VecCorrBijector` from the aforementioned PR.
Bijectors.bijector(::LKJ) = Bijectors.VecCorrBijector();
julia> @model demo() = x ~ LKJ(3, 1);
julia> model = demo();
julia> vi = SimpleVarInfo(model);
julia> # Now it's a matrix.
vi[@varname(x)]
3×3 Matrix{Float64}:
1.0 -0.00803721 -0.849602
-0.00803721 1.0 0.00190424
-0.849602 0.00190424 1.0
julia> vi_transformed = link!!(vi, model);
julia> # Now it's a vector.
vi_transformed[@varname(x)]
3-element Vector{Float64}:
-0.00803738468434096
-1.2547213956880081
-0.0093368799126288
julia> logjoint(model, vi_transformed) # (✓) Works!
-3.515748926181343
In contrast, with VarInfo things are not so simple:
julia> vi = VarInfo(model);
julia> vi[@varname(x)]
3×3 Matrix{Float64}:
1.0 0.382085 0.607741
0.382085 1.0 -0.173265
0.607741 -0.173265 1.0
julia> vi_transformed = link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 3 elements to 9 destinations
Stacktrace:
...
With VarInfo there are multiple challenges:
- link!! occurs in-place and expects the same shape as the original (untransformed) value.
- getindex(vi, vn, dist) uses reconstruct(dist, val) to reshape the underlying flattened representation in VarInfo to what dist expects. This is done before passing it to the bijector/transformation, and so if we're working with a Vector (because we're in transformed space) and then call reconstruct(dist, val::Vector), we get back a Matrix aaaand the inverse transformation, which expects a Vector, fails. We could start looking into potentially adding the transformation used to the reconstruct call, i.e. letting (dist, transform)-pairs define the reconstruct rather than just dist, but then the problem is that in VarInfo whether a variable is transformed or not is decided at runtime, which in turn causes type-instabilities (reconstruct would then return Vector in some cases and Matrix in others, decided upon at runtime).
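Both challenges can be mimicked in plain Julia without DynamicPPL. The sketch below is only an illustration under assumptions: `flat` stands in for the flattened storage inside VarInfo, and `reconstruct_sketch` with its `is_transformed` flag is a hypothetical helper invented here, not DynamicPPL's actual reconstruct.

```julia
n = 3

# Stand-in for the flattened storage of a 3×3 matrix: 9 slots.
flat = zeros(n * n)

# A linked (transformed) value only has binomial(n, 2) = 3 entries.
linked = [-0.008, -1.25, -0.009]

# Challenge 1: writing the linked value into the original slots in place fails.
err = try
    flat[1:(n * n)] = linked
    nothing
catch e
    e
end
println(typeof(err))  # DimensionMismatch

# Challenge 2: a hypothetical reconstruct that knows about the transformation.
# Which branch is taken depends on a runtime Bool, so the return type is
# either Vector{Float64} or Matrix{Float64} and cannot be fixed at compile time.
function reconstruct_sketch(val::Vector{Float64}, n::Int, is_transformed::Bool)
    return is_transformed ? val : reshape(val, n, n)
end

println(typeof(reconstruct_sketch(linked, n, true)))         # Vector{Float64}
println(typeof(reconstruct_sketch(zeros(n * n), n, false)))  # Matrix{Float64}
```

Running `@code_warntype reconstruct_sketch(linked, n, true)` from InteractiveUtils on such a function should flag the non-concrete return type, which is the kind of instability described above.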
So. We need a good way of doing this with VarInfo and I figured I'd make an issue so we can discuss this in more detail together.