
CorrBijector makes posterior improper #228

@sethaxen

Description


An $n \times n$ correlation matrix has ${n \choose 2} = \frac{n (n-1)}{2}$ degrees of freedom. This is the same as the number of elements in a strict upper triangular $n \times n$ matrix. The CorrBijector works by mapping from the correlation matrix first to its unique upper Cholesky factor and then to a strictly upper triangular matrix of unconstrained entries.

The trouble is that in the unconstrained space we now have $n^2$ parameters, of which ${n+1 \choose 2} = \frac{n(n+1)}{2}$ have no impact on the log density. These extra parameters carry an implicit improper uniform prior on the reals, which makes the posterior distribution in unconstrained space improper. Because these parameters have infinite posterior variance, HMC will learn this during adaptation, and their values will explode. I don't know if this will have any negative impact on sampling.
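To see the mechanism concretely, here is a minimal sketch of the kind of tanh/Cholesky construction a correlation-matrix inverse transform uses (an assumed Stan-style parameterization, not the exact Bijectors implementation): only the strict upper triangle of the unconstrained matrix enters the map, so perturbing the diagonal or lower triangle leaves the output unchanged.

```julia
using LinearAlgebra

# Sketch of a Stan-style inverse transform (hypothetical, for illustration):
# only the strict upper triangle of `y` is read, so the diagonal and lower
# triangle are free parameters with no effect on the resulting matrix.
function corr_from_unconstrained(y::AbstractMatrix)
    n = size(y, 1)
    U = zeros(n, n)
    for j in 1:n
        rem = 1.0                      # remaining squared norm of column j
        for i in 1:j-1
            U[i, j] = tanh(y[i, j]) * sqrt(rem)
            rem -= U[i, j]^2
        end
        U[j, j] = sqrt(rem)            # columns of U have unit norm
    end
    return U' * U                      # a valid correlation matrix
end

y = randn(3, 3)
R1 = corr_from_unconstrained(y)
y2 = copy(y)
y2[2, 1] += 100.0                      # perturb a lower-triangular entry
y2[1, 1] -= 50.0                       # and a diagonal entry
@assert corr_from_unconstrained(y2) ≈ R1  # those entries never enter the map
```

Since the log density is constant along those 6 directions (for $n = 3$), nothing in the target restrains them, which is exactly what an improper posterior in unconstrained space means.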

In this demo, we're sampling the uniform distribution on the correlation matrices.

julia> using Turing, Random

julia> @model function model(n, η)
           R ~ LKJ(n, η)
       end;

julia> mod = model(3, 2.0);

julia> Random.seed!(50);

julia> chns = sample(mod, NUTS(0.99), 1_000; save_state=true)
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/51xgc/src/hamiltonian.jl:47
Sampling 100%|█████████████████████████████████████████████████████████████████████████| Time: 0:00:11
Chains MCMC chain (1000×21×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 12.02 seconds
Compute duration  = 12.02 seconds
parameters        = R[1,1], R[2,1], R[3,1], R[1,2], R[2,2], R[3,2], R[1,3], R[2,3], R[3,3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

      R[1,1]    1.0000    0.0000     0.0000    0.0000         NaN       NaN           NaN
      R[2,1]   -0.0086    0.4056     0.0128    0.0131   1010.4052    0.9993       84.0673
      R[3,1]   -0.0081    0.4084     0.0129    0.0113   1034.8774    1.0000       86.1035
      R[1,2]   -0.0086    0.4056     0.0128    0.0131   1010.4052    0.9993       84.0673
      R[2,2]    1.0000    0.0000     0.0000    0.0000   1030.1887    0.9990       85.7133
      R[3,2]   -0.0156    0.4045     0.0128    0.0139    849.6362    1.0033       70.6911
      R[1,3]   -0.0081    0.4084     0.0129    0.0113   1034.8774    1.0000       86.1035
      R[2,3]   -0.0156    0.4045     0.0128    0.0139    849.6362    1.0033       70.6911
      R[3,3]    1.0000    0.0000     0.0000    0.0000   1005.9869    0.9990       83.6997

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      R[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[2,1]   -0.7567   -0.3053   -0.0089    0.2914    0.7302
      R[3,1]   -0.7980   -0.2985   -0.0014    0.2902    0.7525
      R[1,2]   -0.7567   -0.3053   -0.0089    0.2914    0.7302
      R[2,2]    1.0000    1.0000    1.0000    1.0000    1.0000
      R[3,2]   -0.7391   -0.3212   -0.0194    0.3046    0.7256
      R[1,3]   -0.7980   -0.2985   -0.0014    0.2902    0.7525
      R[2,3]   -0.7391   -0.3212   -0.0194    0.3046    0.7256
      R[3,3]    1.0000    1.0000    1.0000    1.0000    1.0000

julia> chns.info.samplerstate.hamiltonian.metric.M⁻¹
9-element Vector{Float64}:
 1.8098190471067061e21
 7.440167311295848e19
 2.0026621564801238e21
 0.24698526419696235
 1.8270061628986394e21
 6.521608457727148e20
 0.27103163446341116
 0.43223828394884933
 3.2173213058057444e20

Note the number of parameters. We should have 3*(3-1)/2 = 3 degrees of freedom, but instead we have 3*3 = 9. And note that 3*(3+1)/2 = 6 of them have adapted variances of ~1e20 in the metric above.

There are at least two ways to solve this, neither of which seems to be possible in Bijectors right now:

  • Preferably, we map from a vector of length ${n \choose 2}$ to the correlation matrix, which would restore bijectivity. But as I understand it, Bijectors currently requires that inputs and outputs have the same size.
  • The bijector can be made bijective again by interpreting it as a map $\mathbb{R}^{n^2} \to \mathcal{C} \times \mathbb{R}^{n+1 \choose 2}$, where $\mathcal{C}$ is the set of correlation matrices. We then need to put a prior on the $\mathbb{R}^{n+1 \choose 2}$ outputs; otherwise they default to an improper uniform prior on the real line. A normal prior makes the most sense. Because Bijectors in reality discards these free parameters, this would effectively require modifying logabsdetjac to contain this prior term. However, when I tried this, it seemed to have no effect; I guess logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) is being called instead of logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}). When mapping from X to y, these extra parameters are all set to 0, so we have no way of applying this prior.
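The first option could look something like this sketch, where `vec_to_corr` is a hypothetical name and the construction again assumes a Stan-style tanh/Cholesky parameterization; the point is only that the input has exactly $n(n-1)/2$ entries:

```julia
using LinearAlgebra

# Hypothetical size-changing bijection: a vector of length n*(n-1)/2
# mapped to an n×n correlation matrix. Every input entry affects the
# output, so nothing is left with an implicit improper prior.
function vec_to_corr(v::AbstractVector, n::Int)
    @assert length(v) == n * (n - 1) ÷ 2
    U = zeros(n, n)
    k = 1
    for j in 1:n
        rem = 1.0
        for i in 1:j-1
            U[i, j] = tanh(v[k]) * sqrt(rem)
            rem -= U[i, j]^2
            k += 1
        end
        U[j, j] = sqrt(rem)
    end
    return U' * U
end

# For n = 3 there are exactly 3 free parameters, matching the DOF count.
R = vec_to_corr(randn(3), 3)
```

This is the shape of the interface the issue asks for; supporting it in Bijectors would require relaxing the same-size constraint on inputs and outputs.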
