@@ -1,14 +1,17 @@
 import logging
 import typing
 import warnings
+from collections.abc import Sequence
 from functools import reduce
 from typing import Literal, cast
 
 import numpy as np
-import scipy.linalg
+import scipy
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor import Variable
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -25,8 +28,6 @@
 
 
 class Cholesky(Op):
-    # TODO: LAPACK wrapper with in-place behavior, for solve also
-
     __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
     gufunc_signature = "(m,m)->(m,m)"
 
@@ -396,6 +397,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): |
     )(A, b)
 
 
+class LU(Op):
+    """Decompose a matrix into lower and upper triangular factors via LU decomposition with partial pivoting."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
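+        # The gufunc signatures below mirror scipy.linalg.lu's return conventions:
+        #   permute_l=True  -> (PL, U)
+        #   p_indices=True  -> (p, L, U), with p a vector of row indices
+        #   otherwise       -> (P, L, U), with P a permutation matrix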
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function, so we copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
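+        # Outputs are single precision only when the input is single precision
+        # (dtype char "f" or "F"); the permutation output is int32 indices when
+        # p_indices=True, and a dense matrix of the same float dtype otherwise.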
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        L = tensor(shape=x.type.shape, dtype=real_dtype)
+        U = tensor(shape=x.type.shape, dtype=real_dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
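+    # Hook used by PyTensor's in-place rewrites: when input 0 may be destroyed,
+    # return a variant of this Op with overwrite_a=True; otherwise keep the
+    # non-destructive Op unchanged.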
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        r"""
+        Derivation is due to "Differentiation of Matrix Functionals Using Triangular
+        Factorization", F. R. de Hoog, R. S. Anderssen, and M. A. Lukas.
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # Permutation matrix is orthogonal
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
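+        # With C = tril(L^T L_bar, strict) + triu(U_bar U^T), the adjoint computed
+        # below is A_bar = P L^{-T} C U^{-T}, using that L has a unit diagonal and
+        # U is upper triangular.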
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a permutation matrix, a unit lower triangular matrix,
+    and an upper triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
+        be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return an array of integer row indices representing the permutation matrix. Otherwise, return the
+        permutation matrix itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for the permutation matrix. Not returned if permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
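+
+    Examples
+    --------
+    A minimal usage sketch (assumes this module is importable as ``pytensor.tensor.slinalg``):
+
+    .. code-block:: python
+
+        import numpy as np
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.tensor.slinalg import lu
+
+        A = pt.dmatrix("A")
+        P, L, U = lu(A)
+        f = pytensor.function([A], [P, L, U])
+
+        p, l, u = f(np.array([[4.0, 3.0], [6.0, 3.0]]))
+        np.testing.assert_allclose(p @ l @ u, [[4.0, 3.0], [6.0, 3.0]])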
+    """
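+    # LU returns 2 outputs when permute_l=True and 3 otherwise; cast() only
+    # narrows the static return type and has no runtime effect.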
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 