Constructing block-banded matrices #20679
**Background:** To avoid the XY problem, first a bit of background in case I'm running down the wrong rabbit hole. I am working with state space models of the form

$$x_t = A x_{t - 1} + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, C), \qquad t = 1, \ldots, n,$$

with the initial state taken as $x_0 = 0$, and I would like to evaluate the posterior distribution of the states $x_1, \ldots, x_n$. The precision matrix has a banded structure because the process is a Markov process. Specifically, the log density is (up to constants independent of $x$)

$$-\frac{1}{2} \sum_{t = 1}^{n} \left(x_t - A x_{t - 1}\right)^\intercal C^{-1} \left(x_t - A x_{t - 1}\right),$$

so the precision matrix is block tridiagonal with off-diagonal blocks $-A^\intercal C^{-1}$ and diagonal blocks $A^\intercal C^{-1} A + C^{-1}$, where the two terms on the diagonal account for each state appearing both as the predicted state and as the predictor of the next state (the last diagonal block is just $C^{-1}$ because there is no subsequent sample).

**The JAX question:** Now to the main part of the question: how can I construct this matrix efficiently (including handling batch dimensions)? In practice, I have […].

**Attempts:** This feels like it should be achievable using a convolution where the rectangular block $\left(-C^{-1} A, A^\intercal C^{-1} A + C^{-1}, -A^\intercal C^{-1}\right)$ acts as the kernel.

Thank you for the great package and reading this far!
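For reference (not part of the original question), the block tridiagonal structure can be written down naively with an explicit loop. This is a sketch, not the thread's code: `naive_precision` is a hypothetical helper, with `Lam` standing in for $C^{-1}$:

```python
import numpy as np


def naive_precision(A, Lam, n):
    """Dense block-tridiagonal precision for x_t = A x_{t-1} + eps_t, x_0 = 0."""
    p = A.shape[-1]
    Q = np.zeros((n * p, n * p))
    offdiag = -A.T @ Lam  # couples consecutive states
    diag = A.T @ Lam @ A + Lam  # interior diagonal blocks
    for t in range(n):
        sl = slice(t * p, (t + 1) * p)
        # The last state has no successor, so only the innovation term remains.
        Q[sl, sl] = Lam if t == n - 1 else diag
        if t < n - 1:
            nxt = slice((t + 1) * p, (t + 2) * p)
            Q[sl, nxt] = offdiag
            Q[nxt, sl] = offdiag.T
    return Q
```

The explicit loop makes the target structure clear but doesn't vectorize or batch; the question is how to do the same efficiently in JAX.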
Replies: 1 comment
Alright, figured it out in case it's useful for anyone. It involves constructing the rectangular block $\left(-C^{-1} A, A^\intercal C^{-1} A + C^{-1}, -A^\intercal C^{-1}\right)$, padding it with zeros on the right, `vmap`-ing over shifts for each row, and some reshaping.

```python
import jax
import jax.numpy as jnp


def state_space_precision(A, innovation_precision, n):
    """
    Evaluate the state space precision.

    Args:
        A: Transition matrix with shape `(..., p, p)`, where `p` is the
            dimensionality of the state space.
        innovation_precision: Precision of innovation noise with shape `(..., p, p)`.
        n: Number of steps.

    Returns:
        Precision of the state with shape `(..., n * p, n * p)`.
    """
    # Broadcast arrays and reshape to have a single batch dimension. We'll restore
    # it later.
    A, innovation_precision = jnp.broadcast_arrays(A, innovation_precision)
    *batch_shape, p, _ = A.shape
    A = A.reshape((-1, p, p))
    innovation_precision = innovation_precision.reshape((-1, p, p))
    # Evaluate one of the blocks of the precision matrix which we'll pad and roll.
    offdiag_block = -jax.lax.batch_matmul(A.mT, innovation_precision)
    diag_block = -jax.lax.batch_matmul(offdiag_block, A) + innovation_precision
    row_block = jnp.concatenate([offdiag_block.mT, diag_block, offdiag_block], axis=-1)
    # Pad for rolling, vmap the rolls for efficiency, and discard the part we
    # don't need.
    result = jnp.pad(row_block, ((0, 0), (0, 0), (0, (n - 1) * p)))
    result = jax.vmap(lambda shift: jnp.roll(result, shift * p, axis=-1))(
        jnp.arange(n)
    )[..., p:-p]
    # Move the rolled dimension to the right position and reshape to get a batch
    # of square matrices.
    result = jnp.moveaxis(result, 0, -3).reshape((-1, n * p, n * p))
    # Set the last diagonal block to the innovation precision because there are no
    # subsequent samples.
    result = result.at[..., -p:, -p:].set(innovation_precision)
    # Restore the old batch shape.
    result = result.reshape((*batch_shape, n * p, n * p))
    return result
```
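As a sanity check on the block structure, the tridiagonal blocks can be compared against the Hessian of the quadratic log-density term via `jax.hessian`. This is a standalone sketch; `A`, `Lam`, `p`, and `n` are illustrative values, not from the discussion:

```python
import jax
import jax.numpy as jnp

p, n = 2, 4
A = jnp.array([[0.9, 0.1], [0.0, 0.8]])
Lam = 2.0 * jnp.eye(p)  # innovation precision C^{-1}


def neg_log_density(x_flat):
    # 0.5 * sum_t (x_t - A x_{t-1})^T Lam (x_t - A x_{t-1}), with x_0 = 0.
    x = x_flat.reshape(n, p)
    x_prev = jnp.concatenate([jnp.zeros((1, p)), x[:-1]], axis=0)
    r = x - x_prev @ A.T
    return 0.5 * jnp.einsum("ti,ij,tj->", r, Lam, r)


# The Hessian of the negative log density is exactly the precision matrix.
Q = jax.hessian(neg_log_density)(jnp.zeros(n * p))
print(jnp.allclose(Q[:p, :p], A.T @ Lam @ A + Lam))  # interior diagonal block
print(jnp.allclose(Q[-p:, -p:], Lam))  # last diagonal block
print(jnp.allclose(Q[:p, p:2 * p], -A.T @ Lam))  # off-diagonal block
```

All three checks print `True`, matching the blocks assembled by `state_space_precision` above.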