+from typing import cast as type_cast
+
 import numpy as np
 import pytensor.tensor as pt
 
+from pytensor.tensor import TensorVariable
+
 from pymc_extras.statespace.utils.constants import (
     ALL_STATE_AUX_DIM,
     ALL_STATE_DIM,
@@ -374,6 +378,258 @@ def conform_time_varying_and_time_invariant_matrices(A, B):
     return A, B
 
 
+def normalize_axis(x, axis):
+    """
+    Convert negative axis values to positive axis values.
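+
+    Examples
+    --------
+    A quick sketch (hypothetical tensor; standard PyTensor API):
+
+    >>> x = pt.tensor3("x")
+    >>> normalize_axis(x, -1)
+    2
+    >>> normalize_axis(x, (-2, -1))
+    (1, 2)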
+    """
+    if isinstance(axis, tuple):
+        return tuple([normalize_axis(x, i) for i in axis])
+    if axis < 0:
+        axis = x.ndim + axis
+    return axis
+
+
+def reorder_from_labels(
+    x: TensorVariable,
+    labels: list[str],
+    ordered_labels: list[str],
+    labeled_axis: int | tuple[int, int],
+) -> TensorVariable:
+    """
+    Reorder an input tensor along the requested axis or axes, based on lists of string labels.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        Input tensor
+    labels: list of str
+        Labels associated with values of the input tensor ``x``, along the ``labeled_axis``. At runtime, should have
+        ``x.shape[labeled_axis] == len(labels)``
+    ordered_labels: list of str
+        Target ordering according to which ``x`` will be reordered.
+    labeled_axis: int or tuple of int
+        Axis along which ``x`` is labeled. If a tuple, each axis is assumed to have identical labels, and
+        reordering will be done on all requested axes together (NOT fancy indexing!)
+
+    Returns
+    -------
+    x_sorted: TensorVariable
+        Output tensor sorted along ``labeled_axis`` according to ``ordered_labels``
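+
+    Examples
+    --------
+    A small sketch with concrete values (hypothetical labels; uses the standard ``.eval()`` debugging API):
+
+    >>> x = pt.as_tensor([1.0, 2.0, 3.0])
+    >>> reorder_from_labels(x, ["b", "c", "a"], ["a", "b", "c"], labeled_axis=0).eval()
+    array([3., 1., 2.])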
+    """
+    n_out = len(ordered_labels)
+    label_to_index = {label: index for index, label in enumerate(ordered_labels)}
+
+    # Labels in ``ordered_labels`` that are missing from ``labels`` are assumed to sit at the end of the axis
+    # (this is where the padding helpers below place them)
+    missing_labels = [label for label in ordered_labels if label not in labels]
+
+    # Look up the target position of each current position; argsort converts these targets into the permutation
+    # that sorts the axis into ``ordered_labels`` order
+    indices = np.argsort([label_to_index[label] for label in [*labels, *missing_labels]])
+
+    if isinstance(labeled_axis, int):
+        labeled_axis = (labeled_axis,)
+
+    # Skip the indexing entirely when the permutation is already the identity
+    if indices.tolist() != list(range(n_out)):
+        for axis in labeled_axis:
+            idx = np.s_[tuple([slice(None, None) if i != axis else indices for i in range(x.ndim)])]
+            x = x[idx]
+
+    return x
+
+
+def pad_and_reorder(
+    x: TensorVariable, labels: list[str], ordered_labels: list[str], labeled_axis: int
+) -> TensorVariable:
+    """
+    Pad input tensor ``x`` along the ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
+    padded dimension to match the ordering in ``ordered_labels``.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        Input tensor
+    labels: list of str
+        String labels associated with the ``x`` tensor at the ``labeled_axis`` dimension. At runtime, should have
+        ``x.shape[labeled_axis] == len(labels)``. ``labels`` should be a subset of ``ordered_labels``.
+    ordered_labels: list of str
+        Target ordering according to which ``x`` will be reordered.
+    labeled_axis: int
+        Axis along which ``x`` is labeled.
+
+    Returns
+    -------
+    x_padded: TensorVariable
+        Output tensor padded along ``labeled_axis`` to the length of ``ordered_labels``, then reordered.
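+
+    Examples
+    --------
+    A sketch in which a missing label is padded with zero (hypothetical labels and values):
+
+    >>> x = pt.as_tensor([1.0, 2.0])
+    >>> pad_and_reorder(x, ["b", "a"], ["a", "b", "c"], labeled_axis=0).eval()
+    array([2., 1., 0.])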
+    """
+    n_out = len(ordered_labels)
+    n_missing = n_out - len(labels)
+
+    if n_missing > 0:
+        # Append zeros along the labeled axis for labels present in ``ordered_labels`` but absent from ``labels``
+        zeros = pt.zeros(
+            tuple([x.shape[i] if i != labeled_axis else n_missing for i in range(x.ndim)])
+        )
+        x_padded = pt.concatenate([x, zeros], axis=labeled_axis)
+    else:
+        x_padded = x
+
+    return reorder_from_labels(x_padded, labels, ordered_labels, labeled_axis)
+
+
+def ndim_pad_and_reorder(
+    x: TensorVariable,
+    labels: list[str],
+    ordered_labels: list[str],
+    labeled_axis: int | tuple[int, int],
+) -> TensorVariable:
+    """
+    Pad input tensor ``x`` along the ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
+    padded dimension to match the ordering in ``ordered_labels``.
+
+    Unlike ``pad_and_reorder``, this function allows padding and reordering to be done simultaneously on multiple
+    axes. In this case, reordering is done jointly on all axes -- it does *not* use fancy indexing.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        Input tensor
+    labels: list of str
+        Labels associated with values of the input tensor ``x``, along the ``labeled_axis``. At runtime, should have
+        ``x.shape[labeled_axis] == len(labels)``. If ``labeled_axis`` is a tuple, all axes are assumed to have the
+        same labels.
+    ordered_labels: list of str
+        Target ordering according to which ``x`` will be reordered. ``labels`` should be a subset of ``ordered_labels``.
+    labeled_axis: int or tuple of int
+        Axis along which ``x`` is labeled. If a tuple, each axis is assumed to have identical labels, and
+        reordering will be done on all requested axes together (NOT fancy indexing!)
+
+    Returns
+    -------
+    x_sorted: TensorVariable
+        Output tensor. Each ``labeled_axis`` is padded to the length of ``ordered_labels``, then reordered.
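+
+    Examples
+    --------
+    A covariance-style sketch, padding and reordering both axes together (hypothetical labels and values):
+
+    >>> cov = pt.as_tensor([[4.0, 1.0], [1.0, 9.0]])
+    >>> ndim_pad_and_reorder(cov, ["b", "a"], ["a", "b", "c"], labeled_axis=(0, 1)).eval()
+    array([[9., 1., 0.],
+           [1., 4., 0.],
+           [0., 0., 0.]])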
+    """
+    n_missing = len(ordered_labels) - len(labels)
+
+    if isinstance(labeled_axis, int):
+        labeled_axis = (labeled_axis,)
+
+    if n_missing > 0:
+        # Zero-pad the end of every labeled axis; the missing labels are appended, then sorted into place
+        pad_size = [(0, 0) if i not in labeled_axis else (0, n_missing) for i in range(x.ndim)]
+        x = pt.pad(x, pad_size, mode="constant", constant_values=0)
+
+    return reorder_from_labels(x, labels, ordered_labels, labeled_axis)
+
+
+def add_tensors_by_dim_labels(
+    tensor: TensorVariable,
+    other_tensor: TensorVariable,
+    labels: list[str],
+    other_labels: list[str],
+    labeled_axis: int | tuple[int, int] = -1,
+) -> TensorVariable:
+    """
+    Add two tensors based on labels associated with one dimension.
+
+    When combining statespace matrices associated with structural components with potentially different states, it is
+    important to make sure that duplicated states are handled correctly. For bias vectors and covariance matrices,
+    duplicated states should be summed.
+
+    When a state appears in one component but not another, that state should be treated as an implicit zero in the
+    components where the state does not appear. This amounts to padding the relevant matrices with zeros before
+    performing the addition.
+
+    When ``labeled_axis`` is a tuple, every indicated axis is assumed to carry the same labels. This is the case,
+    for example, when working with a covariance matrix. In this case, padding and alignment will be done on each
+    indicated axis.
+
+    Parameters
+    ----------
+    tensor: TensorVariable
+        A statespace matrix to be summed with ``other_tensor``.
+    other_tensor: TensorVariable
+        A statespace matrix to be summed with ``tensor``.
+    labels: list of str
+        Dimension labels associated with ``tensor``, on the ``labeled_axis`` dimension.
+    other_labels: list of str
+        Dimension labels associated with ``other_tensor``, on the ``labeled_axis`` dimension.
+    labeled_axis: int or tuple of int
+        Dimension that is labeled by ``labels`` and ``other_labels``. ``tensor.shape[labeled_axis]`` must equal
+        ``len(labels)`` at runtime.
+
+    Returns
+    -------
+    result: TensorVariable
+        Result of adding ``tensor`` and ``other_tensor`` along the ``labeled_axis`` dimension. The ordering of
+        the output will be ``labels + [label for label in other_labels if label not in labels]``. That is, ``labels``
+        come first, followed by any new labels introduced by ``other_labels``.
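+
+    Examples
+    --------
+    Summing two labeled bias vectors with partial overlap (hypothetical labels and values):
+
+    >>> a = pt.as_tensor([1.0, 2.0])     # labels ["x", "y"]
+    >>> b = pt.as_tensor([10.0, 20.0])   # labels ["y", "z"]
+    >>> add_tensors_by_dim_labels(a, b, ["x", "y"], ["y", "z"], labeled_axis=0).eval()
+    array([ 1., 12., 20.])
+
+    With fully disjoint labels, the inputs are simply concatenated:
+
+    >>> add_tensors_by_dim_labels(a, b, ["x", "y"], ["w", "z"], labeled_axis=0).eval()
+    array([ 1.,  2., 10., 20.])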
+    """
+    labeled_axis = normalize_axis(tensor, labeled_axis)
+    new_labels = [label for label in other_labels if label not in labels]
+    combined_labels = type_cast(list[str], [*labels, *new_labels])
+
+    # If there is no overlap at all, directly concatenate the two matrices -- there's no need to worry about the
+    # order of things, or padding. This is equivalent to padding both out with zeros then adding them.
+    if combined_labels == [*labels, *other_labels]:
+        if isinstance(labeled_axis, int):
+            return pt.concatenate([tensor, other_tensor], axis=labeled_axis)
+        else:
+            # In the case where we want to align multiple dimensions, use block_diag to accomplish padding on the
+            # last two dimensions
+            dims = [*[i for i in range(tensor.ndim) if i not in labeled_axis], *labeled_axis]
+            return pt.linalg.block_diag(
+                type_cast(TensorVariable, tensor.transpose(*dims)),
+                type_cast(TensorVariable, other_tensor.transpose(*dims)),
+            )
+
+    # Otherwise, there are two possibilities. If all labels are the same, we might need to re-order one or both to
+    # get them to agree. If *some* labels are the same, we will need to pad first, then potentially re-order. In any
+    # case, the final step is just to add the padded and re-ordered tensors.
+    fn = pad_and_reorder if isinstance(labeled_axis, int) else ndim_pad_and_reorder
+
+    padded_tensor = fn(
+        tensor,
+        labels=type_cast(list[str], labels),
+        ordered_labels=combined_labels,
+        labeled_axis=labeled_axis,
+    )
+    padded_tensor.name = tensor.name
+
+    padded_other_tensor = fn(
+        other_tensor,
+        labels=type_cast(list[str], other_labels),
+        ordered_labels=combined_labels,
+        labeled_axis=labeled_axis,
+    )
+    padded_other_tensor.name = other_tensor.name
+
+    return padded_tensor + padded_other_tensor
+
+
+def join_tensors_by_dim_labels(
+    tensor: TensorVariable,
+    other_tensor: TensorVariable,
+    labels: list[str],
+    other_labels: list[str],
+    labeled_axis: int = -1,
+    join_axis: int = -1,
+    block_diag_join: bool = False,
+) -> TensorVariable:
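+    """
+    Join two tensors along ``join_axis``, aligning their ``labeled_axis`` dimensions by label.
+
+    Labels shared by both inputs are aligned to a common position on the ``labeled_axis``; labels that appear in
+    only one input are zero-padded in the other. The output label ordering is ``labels`` first, followed by any new
+    labels introduced by ``other_labels``. When the two label sets are disjoint, the join is always block-diagonal,
+    which implicitly pads zeros everywhere they are needed.
+
+    Parameters
+    ----------
+    tensor: TensorVariable
+        A statespace matrix to be joined with ``other_tensor``.
+    other_tensor: TensorVariable
+        A statespace matrix to be joined with ``tensor``.
+    labels: list of str
+        Dimension labels associated with ``tensor``, on the ``labeled_axis`` dimension.
+    other_labels: list of str
+        Dimension labels associated with ``other_tensor``, on the ``labeled_axis`` dimension.
+    labeled_axis: int
+        Dimension that is labeled by ``labels`` and ``other_labels``.
+    join_axis: int
+        Dimension along which the aligned tensors are concatenated.
+    block_diag_join: bool
+        If True, join the aligned tensors with ``pt.linalg.block_diag`` instead of concatenating them.
+
+    Returns
+    -------
+    result: TensorVariable
+        The joined tensor.
+
+    Examples
+    --------
+    Stacking two design-matrix-style tensors with partially overlapping state labels (hypothetical values):
+
+    >>> Z1 = pt.as_tensor([[1.0, 0.0]])  # labels ["x", "y"]
+    >>> Z2 = pt.as_tensor([[2.0, 3.0]])  # labels ["y", "z"]
+    >>> join_tensors_by_dim_labels(Z1, Z2, ["x", "y"], ["y", "z"], labeled_axis=-1, join_axis=0).eval()
+    array([[1., 0., 0.],
+           [0., 2., 3.]])
+    """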
+    labeled_axis = normalize_axis(tensor, labeled_axis)
+    new_labels = [label for label in other_labels if label not in labels]
+    combined_labels = [*labels, *new_labels]
+
+    # Check for no overlap first. In this case, do a block-diagonal join, which implicitly results in padding zeros
+    # everywhere they are needed -- no other sorting or padding necessary
+    if combined_labels == [*labels, *other_labels]:
+        return pt.linalg.block_diag(tensor, other_tensor)
+
+    # Otherwise there is either total or partial overlap. Let the padding and reordering function figure it out.
+    tensor = ndim_pad_and_reorder(tensor, labels, combined_labels, labeled_axis)
+    other_tensor = ndim_pad_and_reorder(other_tensor, other_labels, combined_labels, labeled_axis)
+
+    if block_diag_join:
+        return pt.linalg.block_diag(tensor, other_tensor)
+    else:
+        return pt.concatenate([tensor, other_tensor], axis=join_axis)
+
+
 def get_exog_dims_from_idata(exog_name, idata):
     if exog_name in idata.posterior.data_vars:
         exog_dims = idata.posterior[exog_name].dims[2:]