Skip to content

Commit b970a6c

Browse files
Allow combination of components with different numbers of observed states
1 parent a70b733 commit b970a6c

File tree

4 files changed

+583
-8
lines changed

4 files changed

+583
-8
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation
1515
from pymc_extras.statespace.models.utilities import (
16+
add_tensors_by_dim_labels,
1617
conform_time_varying_and_time_invariant_matrices,
18+
join_tensors_by_dim_labels,
1719
make_default_coords,
1820
)
1921
from pymc_extras.statespace.utils.constants import (
@@ -481,11 +483,13 @@ def populate_component_properties(self):
481483
def _get_combined_shapes(self, other):
    """
    Compute the state-space dimensions of the model obtained by merging ``self`` with ``other``.

    Hidden states and shocks of the two components are stacked, so ``k_states`` and ``k_posdef`` are plain sums.
    Observed states are merged by name: the combined matrices are aligned on ``observed_state_names``, so the number
    of observed states is the number of names left after combining the two name lists without duplicates. Counting
    names (rather than reusing ``self.k_endog`` whenever the two counts happen to match) keeps ``k_endog`` correct
    when both components observe the same *number* of states but *different* states.

    Parameters
    ----------
    other
        The component being merged into this one. Must expose ``k_states``, ``k_posdef`` and
        ``observed_state_names``.

    Returns
    -------
    k_states: int
        Total number of hidden states.
    k_posdef: int
        Total number of shocks.
    k_endog: int
        Number of observed states in the merged component.
    """
    k_states = self.k_states + other.k_states
    k_posdef = self.k_posdef + other.k_posdef

    # NOTE(review): relies on ``_combine_property(..., allow_duplicates=False)`` returning the de-duplicated
    # union of the two name lists -- confirm against its definition.
    combined_states = self._combine_property(other, "observed_state_names", allow_duplicates=False)
    k_endog = len(combined_states)

    return k_states, k_posdef, k_endog
491495

@@ -499,6 +503,9 @@ def make_slice(name, x, o_x):
499503
self_matrices = [self.ssm[name] for name in LONG_MATRIX_NAMES]
500504
other_matrices = [other.ssm[name] for name in LONG_MATRIX_NAMES]
501505

506+
self_observed_states = self.observed_state_names
507+
other_observed_states = other.observed_state_names
508+
502509
x0, P0, c, d, T, Z, R, H, Q = (
503510
self.ssm[make_slice(name, x, o_x)]
504511
for name, x, o_x in zip(LONG_MATRIX_NAMES, self_matrices, other_matrices)
@@ -517,19 +524,33 @@ def make_slice(name, x, o_x):
517524
state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c))
518525
state_intercept.name = c.name
519526

520-
obs_intercept = d + o_d
527+
obs_intercept = add_tensors_by_dim_labels(
528+
d, o_d, labels=self_observed_states, other_labels=other_observed_states, labeled_axis=-1
529+
)
521530
obs_intercept.name = d.name
522531

523532
transition = pt.linalg.block_diag(T, o_T)
524533
transition.name = T.name
525534

526-
design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1)
535+
design = join_tensors_by_dim_labels(
536+
*conform_time_varying_and_time_invariant_matrices(Z, o_Z),
537+
labels=self_observed_states,
538+
other_labels=other_observed_states,
539+
labeled_axis=-2,
540+
join_axis=-1,
541+
)
527542
design.name = Z.name
528543

529544
selection = pt.linalg.block_diag(R, o_R)
530545
selection.name = R.name
531546

532-
obs_cov = H + o_H
547+
obs_cov = add_tensors_by_dim_labels(
548+
H,
549+
o_H,
550+
labels=self_observed_states,
551+
other_labels=other_observed_states,
552+
labeled_axis=(-1, -2),
553+
)
533554
obs_cov.name = H.name
534555

535556
state_cov = pt.linalg.block_diag(Q, o_Q)

pymc_extras/statespace/models/utilities.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from typing import cast as type_cast
2+
13
import numpy as np
24
import pytensor.tensor as pt
35

6+
from pytensor.tensor import TensorVariable
7+
48
from pymc_extras.statespace.utils.constants import (
59
ALL_STATE_AUX_DIM,
610
ALL_STATE_DIM,
@@ -374,6 +378,258 @@ def conform_time_varying_and_time_invariant_matrices(A, B):
374378
return A, B
375379

376380

381+
def normalize_axis(x, axis):
    """
    Convert negative axis values to their non-negative equivalents.

    Parameters
    ----------
    x
        Any object exposing an integer ``ndim`` attribute (for example a tensor or an array).
    axis: int or sequence of int
        Axis, or tuple/list of axes, to normalize. Negative values count from the last dimension of ``x``.

    Returns
    -------
    axis: int or tuple of int
        A single int when ``axis`` is an int, otherwise a tuple with one normalized entry per input axis.

    Raises
    ------
    ValueError
        If an axis lies outside ``[-x.ndim, x.ndim)``. Such a value can never match a real dimension, and callers
        that compare it against ``range(x.ndim)`` would otherwise silently do nothing.
    """
    if isinstance(axis, (tuple, list)):
        return tuple(normalize_axis(x, i) for i in axis)

    ndim = x.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for a tensor with {ndim} dimensions")

    return axis + ndim if axis < 0 else axis
390+
391+
392+
def reorder_from_labels(
    x: TensorVariable,
    labels: list[str],
    ordered_labels: list[str],
    labeled_axis: int | tuple[int, int],
) -> TensorVariable:
    """
    Reorder an input tensor along the requested axis/axes based on lists of string labels.

    Parameters
    ----------
    x: TensorVariable
        Input tensor. Along every labeled axis it must already have length ``len(ordered_labels)``: the leading
        entries correspond to ``labels`` (in that order) and any trailing entries correspond to the labels of
        ``ordered_labels`` that do not appear in ``labels``, in the order they appear in ``ordered_labels``.
    labels: list of str
        Labels associated with the leading entries of ``x`` along ``labeled_axis``. Every label must appear in
        ``ordered_labels``; an unknown label raises ``KeyError``.
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered.
    labeled_axis: int or tuple of int
        Axis along which ``x`` is labeled. Negative values count from the last axis. If a tuple, every axis is
        assumed to carry identical labels and the same permutation is applied to each axis in turn (one axis at a
        time, NOT a single joint fancy-indexing operation).

    Returns
    -------
    x_sorted: TensorVariable
        Tensor sorted along ``labeled_axis`` according to ``ordered_labels``. When the entries are already in the
        target order, ``x`` itself is returned unchanged.
    """
    n_out = len(ordered_labels)
    label_to_index = {label: index for index, label in enumerate(ordered_labels)}

    # Labels that x does not carry are assumed to sit after the known ones (e.g. zero padding appended by a caller).
    known_labels = set(labels)
    missing_labels = [label for label in ordered_labels if label not in known_labels]
    indices = np.argsort([label_to_index[label] for label in [*labels, *missing_labels]])

    # Already in the target order: nothing to do.
    if indices.tolist() == list(range(n_out)):
        return x

    if isinstance(labeled_axis, (int, np.integer)):
        labeled_axis = (labeled_axis,)

    # Resolve negative axes up front; the comparison against range(x.ndim) below would otherwise never match them
    # and the reordering would be silently skipped.
    axes = tuple(axis + x.ndim if axis < 0 else axis for axis in labeled_axis)

    for axis in axes:
        idx = tuple(indices if i == axis else slice(None) for i in range(x.ndim))
        x = x[idx]

    return x
434+
435+
436+
def pad_and_reorder(
    x: TensorVariable, labels: list[str], ordered_labels: list[str], labeled_axis: int
) -> TensorVariable:
    """
    Pad input tensor ``x`` along ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
    padded dimension to match the ordering in ``ordered_labels``.

    Parameters
    ----------
    x: TensorVariable
        Input tensor
    labels: list of str
        String labels associated with the ``x`` tensor at the ``labeled_axis`` dimension. At runtime, should have
        ``x.shape[labeled_axis] == len(labels)``. ``labels`` should be a subset of ``ordered_labels``.
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered.
    labeled_axis: int
        Axis along which ``x`` will be labeled. Negative values count from the last axis.

    Returns
    -------
    x_padded: TensorVariable
        Output tensor padded with zeros along ``labeled_axis`` up to ``len(ordered_labels)``, then reordered.
    """
    # Resolve a negative axis first: the shape comprehension below compares against non-negative positions, so a
    # raw negative axis would yield a zero block with the wrong shape.
    if labeled_axis < 0:
        labeled_axis = x.ndim + labeled_axis

    n_missing = len(ordered_labels) - len(labels)

    if n_missing > 0:
        zeros_shape = tuple(n_missing if i == labeled_axis else x.shape[i] for i in range(x.ndim))
        # Match the dtype of x so that padding does not change the dtype of the result.
        zeros = pt.zeros(zeros_shape, dtype=x.dtype)
        x = pt.concatenate([x, zeros], axis=labeled_axis)

    return reorder_from_labels(x, labels, ordered_labels, labeled_axis)
473+
474+
475+
def ndim_pad_and_reorder(
    x: TensorVariable,
    labels: list[str],
    ordered_labels: list[str],
    labeled_axis: int | tuple[int, int],
) -> TensorVariable:
    """
    Pad input tensor ``x`` along ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
    padded dimension(s) to match the ordering in ``ordered_labels``.

    Unlike ``pad_and_reorder``, this function allows padding and reordering to be done on multiple axes in one call.
    The same padding and the same permutation are applied to every labeled axis.

    Parameters
    ----------
    x: TensorVariable
        Input tensor
    labels: list of str
        Labels associated with values of the input tensor ``x``, along the ``labeled_axis``. At runtime, should have
        ``x.shape[labeled_axis] == len(labels)``. If ``labeled_axis`` is a tuple, all axes are assumed to have the
        same labels.
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered. ``labels`` should be a subset of
        ``ordered_labels``.
    labeled_axis: int or tuple of int
        Axis along which ``x`` will be labeled. Negative values count from the last axis. If a tuple, each axis is
        assumed to have identical labels, and padding/reordering is applied to all requested axes.

    Returns
    -------
    x_sorted: TensorVariable
        Output tensor. Each ``labeled_axis`` is zero-padded to the length of ``ordered_labels``, then reordered.
    """
    if isinstance(labeled_axis, int):
        labeled_axis = (labeled_axis,)

    # Resolve negative axes so the membership test against range(x.ndim) below can match them; otherwise no
    # padding would be applied at all.
    labeled_axis = tuple(axis + x.ndim if axis < 0 else axis for axis in labeled_axis)

    n_missing = len(ordered_labels) - len(labels)

    if n_missing > 0:
        pad_size = [(0, n_missing) if i in labeled_axis else (0, 0) for i in range(x.ndim)]
        x = pt.pad(x, pad_size, mode="constant", constant_values=0)

    return reorder_from_labels(x, labels, ordered_labels, labeled_axis)
517+
518+
519+
def add_tensors_by_dim_labels(
    tensor: TensorVariable,
    other_tensor: TensorVariable,
    labels: list[str],
    other_labels: list[str],
    labeled_axis: int | tuple[int, int] = -1,
) -> TensorVariable:
    """
    Add two tensors based on labels associated with one dimension.

    When combining statespace matrices associated with structural components with potentially different states, it is
    important to make sure that duplicated states are handled correctly. For bias vectors and covariance matrices,
    duplicated states should be summed.

    When a state appears in one component but not another, that state should be treated as an implicit zero in the
    components where the state does not appear. This amounts to padding the relevant matrices with zeros before
    performing the addition.

    When ``labeled_axis`` is a tuple, each listed axis is assumed to be identically labeled in each input tensor.
    This is the case, for example, when working with a covariance matrix. In this case, padding and alignment will be
    done on each indicated axis.

    Parameters
    ----------
    tensor: TensorVariable
        A statespace matrix to be summed with ``other_tensor``.
    other_tensor: TensorVariable
        A statespace matrix to be summed with ``tensor``.
    labels: list of str
        Dimension labels associated with ``tensor``, on the ``labeled_axis`` dimension.
    other_labels: list of str
        Dimension labels associated with ``other_tensor``, on the ``labeled_axis`` dimension.
    labeled_axis: int or tuple of int
        Dimension that is labeled by ``labels`` and ``other_labels``. ``tensor.shape[labeled_axis]`` must equal
        ``len(labels)`` at runtime. Negative values are resolved against ``tensor.ndim``.

    Returns
    -------
    result: TensorVariable
        Result of addition of ``tensor`` and ``other_tensor``, along the ``labeled_axis`` dimension. The ordering of
        the output will be ``labels + [label for label in other_labels if label not in labels]``. That is, ``labels``
        come first, followed by any new labels introduced by ``other_labels``. The axis layout of the result matches
        that of ``tensor``.
    """
    labeled_axis = normalize_axis(tensor, labeled_axis)
    new_labels = [label for label in other_labels if label not in labels]
    combined_labels = type_cast(list[str], [*labels, *new_labels])

    # If there is no overlap at all, directly join the two tensors -- there's no need to worry about the order
    # of things, or padding. This is equivalent to padding both out with zeros then adding them.
    if combined_labels == [*labels, *other_labels]:
        if isinstance(labeled_axis, int):
            return pt.concatenate([tensor, other_tensor], axis=labeled_axis)

        # To align multiple dimensions, block_diag pads the last two dimensions. Move the labeled axes to the end
        # keeping their relative order, so the blocks are not transposed when the caller lists them in a
        # different order such as (-1, -2).
        # NOTE(review): the permutation is built from ``tensor.ndim``; this assumes ``other_tensor`` has the same
        # number of dimensions -- confirm for time-varying inputs.
        dims = [*[i for i in range(tensor.ndim) if i not in labeled_axis], *sorted(labeled_axis)]
        if dims == list(range(tensor.ndim)):
            return pt.linalg.block_diag(tensor, other_tensor)

        joined = pt.linalg.block_diag(
            type_cast(TensorVariable, tensor.transpose(*dims)),
            type_cast(TensorVariable, other_tensor.transpose(*dims)),
        )
        # Put the axes back where they were in the inputs (inverse permutation of ``dims``).
        inverse_dims = sorted(range(len(dims)), key=dims.__getitem__)
        return joined.transpose(*inverse_dims)

    # Otherwise, there are two possibilities. If all labels are the same, we might need to re-order one or both to get
    # them to agree. If *some* labels are the same, we will need to pad first, then potentially re-order. In any case,
    # the final step is just to add the padded and re-ordered tensors.
    fn = pad_and_reorder if isinstance(labeled_axis, int) else ndim_pad_and_reorder

    padded_tensor = fn(
        tensor,
        labels=type_cast(list[str], labels),
        ordered_labels=combined_labels,
        labeled_axis=labeled_axis,
    )
    padded_tensor.name = tensor.name

    padded_other_tensor = fn(
        other_tensor,
        labels=type_cast(list[str], other_labels),
        ordered_labels=combined_labels,
        labeled_axis=labeled_axis,
    )
    padded_other_tensor.name = other_tensor.name

    return padded_tensor + padded_other_tensor
603+
604+
605+
def join_tensors_by_dim_labels(
    tensor: TensorVariable,
    other_tensor: TensorVariable,
    labels: list[str],
    other_labels: list[str],
    labeled_axis: int = -1,
    join_axis: int = -1,
    block_diag_join: bool = False,
) -> TensorVariable:
    """
    Join two tensors after aligning them on a labeled dimension.

    Parameters
    ----------
    tensor: TensorVariable
        First tensor to join.
    other_tensor: TensorVariable
        Second tensor to join.
    labels: list of str
        Labels associated with ``tensor`` along ``labeled_axis``.
    other_labels: list of str
        Labels associated with ``other_tensor`` along ``labeled_axis``.
    labeled_axis: int
        Axis labeled by ``labels`` and ``other_labels``. Negative values are resolved against ``tensor.ndim``.
    join_axis: int
        Axis along which the aligned tensors are concatenated when ``block_diag_join`` is False.
    block_diag_join: bool
        If True, the aligned tensors are combined with ``pt.linalg.block_diag`` instead of being concatenated.

    Returns
    -------
    result: TensorVariable
        If ``labels`` and ``other_labels`` share no label, the block-diagonal combination of the two inputs
        (``join_axis`` and ``block_diag_join`` are not consulted in this case). Otherwise both inputs are padded and
        reordered along ``labeled_axis`` to the ordering
        ``labels + [label for label in other_labels if label not in labels]`` and then joined.
    """
    axis = normalize_axis(tensor, labeled_axis)
    extra_labels = [name for name in other_labels if name not in labels]
    union_labels = [*labels, *extra_labels]

    # Disjoint label sets: a block-diagonal join implicitly places zeros everywhere they are needed, so no
    # explicit padding or sorting is required.
    if union_labels == [*labels, *other_labels]:
        return pt.linalg.block_diag(tensor, other_tensor)

    # Total or partial overlap: bring both inputs onto the shared label ordering before joining.
    aligned = [
        ndim_pad_and_reorder(item, item_labels, union_labels, axis)
        for item, item_labels in ((tensor, labels), (other_tensor, other_labels))
    ]

    if block_diag_join:
        return pt.linalg.block_diag(*aligned)
    return pt.concatenate(aligned, axis=join_axis)
631+
632+
377633
def get_exog_dims_from_idata(exog_name, idata):
378634
if exog_name in idata.posterior.data_vars:
379635
exog_dims = idata.posterior[exog_name].dims[2:]

0 commit comments

Comments
 (0)