|
1 | 1 | import logging |
2 | | -import typing |
3 | 2 | import warnings |
| 3 | +from collections.abc import Sequence |
4 | 4 | from functools import reduce |
5 | 5 | from typing import Literal, cast |
6 | 6 |
|
|
9 | 9 |
|
10 | 10 | import pytensor |
11 | 11 | import pytensor.tensor as pt |
| 12 | +from pytensor.gradient import DisconnectedType |
12 | 13 | from pytensor.graph.basic import Apply |
13 | 14 | from pytensor.graph.op import Op |
14 | 15 | from pytensor.tensor import TensorLike, as_tensor_variable |
@@ -295,31 +296,16 @@ def L_op(self, inputs, outputs, output_gradients): |
295 | 296 | # We need to return (dC/d[inv(A)], dC/db) |
296 | 297 | c_bar = output_gradients[0] |
297 | 298 |
|
298 | | - solve_args = {k: getattr(self, k) for k in self.__props__} |
299 | | - |
300 | | - # Some solvers can solve A.T x = b directly, without ever computing the transpose |
301 | | - has_trans = "transposed" in self.__props__ |
302 | | - |
303 | | - if has_trans: |
304 | | - # If the solver can do transposed solves, we do the opposite of the forward in the reverse. If we solved |
305 | | - # C = solve(A, b), then b_bar = solve(A.T, c_bar). If we solved C = solve(A.T, b), then |
306 | | - # b_bar = solve(A, c_bar) |
307 | | - solve_args["transposed"] = not solve_args["transposed"] |
308 | | - solve_op = type(self)(**solve_args) |
309 | | - b_bar = solve_op(A, c_bar) |
310 | | - |
311 | | - else: |
312 | | - # Otherwise, we have to actually do the transpose of whatever was given |
313 | | - solve_op = type(self)(**solve_args) |
314 | | - b_bar = solve_op(A.T, c_bar) |
| 299 | + trans_solve_op = type(self)( |
| 300 | + **{ |
| 301 | + k: (not getattr(self, k) if k == "lower" else getattr(self, k)) |
| 302 | + for k in self.__props__ |
| 303 | + } |
| 304 | + ) |
| 305 | + b_bar = trans_solve_op(A.T, c_bar) |
315 | 306 |
|
316 | 307 | # force outer product if vector second input |
317 | | - A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar @ c.T |
318 | | - |
319 | | - if has_trans and not solve_args["transposed"]: |
320 | | - # If we did a transposed solve in the forward pass, the program is expecting the |
321 | | - # gradients of A.T, not A |
322 | | - A_bar = A_bar.T |
| 308 | + A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) |
323 | 309 |
|
324 | 310 | return [A_bar, b_bar] |
325 | 311 |
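
As a sanity check on the simplified rule above (`b_bar = trans_solve_op(A.T, c_bar)`, `A_bar = -outer(b_bar, c)`), the vector-Jacobian product can be compared against finite differences with plain scipy. A minimal sketch; the matrix size, seed, and probed entry are arbitrary choices, not part of the diff:

```python
import numpy as np
from scipy.linalg import solve

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # well-conditioned test matrix
b = rng.normal(size=4)
c_bar = rng.normal(size=4)  # incoming gradient on c = solve(A, b)

# VJP from the rule above
c = solve(A, b)
b_bar = solve(A.T, c_bar)
A_bar = -np.outer(b_bar, c)

# Finite-difference derivative of phi(A) = c_bar @ solve(A, b) w.r.t. A[1, 2]
eps = 1e-6
E = np.zeros_like(A)
E[1, 2] = eps
fd = (c_bar @ solve(A + E, b) - c_bar @ solve(A - E, b)) / (2 * eps)
np.testing.assert_allclose(fd, A_bar[1, 2], rtol=1e-5)
```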
|
@@ -396,6 +382,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): |
396 | 382 | )(A, b) |
397 | 383 |
|
398 | 384 |
|
| 385 | +class LU(Op): |
|  386 | +    """Decompose a matrix into permutation, unit lower triangular, and upper triangular factors (pivoted LU decomposition).""" |
| 387 | + |
| 388 | + __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices") |
| 389 | + |
| 390 | + def __init__( |
| 391 | + self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False |
| 392 | + ): |
| 393 | + self.permute_l = permute_l |
| 394 | + self.check_finite = check_finite |
| 395 | + self.p_indices = p_indices |
| 396 | + self.overwrite_a = overwrite_a |
| 397 | + |
| 398 | + if self.permute_l: |
|  399 | +            # permute_l overrides p_indices in the scipy function, so we mirror that behavior here |
| 400 | + self.gufunc_signature = "(m,m)->(m,m),(m,m)" |
| 401 | + elif self.p_indices: |
| 402 | + self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)" |
| 403 | + else: |
| 404 | + self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)" |
| 405 | + |
| 406 | + if self.overwrite_a: |
| 407 | + self.destroy_map = {0: [0]} |
| 408 | + |
| 409 | + def infer_shape(self, fgraph, node, shapes): |
| 410 | + n = shapes[0][0] |
| 411 | + if self.permute_l: |
| 412 | + return [(n, n), (n, n)] |
| 413 | + elif self.p_indices: |
| 414 | + return [(n,), (n, n), (n, n)] |
| 415 | + else: |
| 416 | + return [(n, n), (n, n), (n, n)] |
| 417 | + |
| 418 | + def make_node(self, x): |
| 419 | + x = as_tensor_variable(x) |
| 420 | + if x.type.ndim != 2: |
| 421 | + raise TypeError( |
| 422 | + f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" |
| 423 | + ) |
| 424 | + |
| 425 | + real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d" |
| 426 | + p_dtype = "int32" if self.p_indices else np.dtype(real_dtype) |
| 427 | + |
| 428 | + L = tensor(shape=x.type.shape, dtype=real_dtype) |
| 429 | + U = tensor(shape=x.type.shape, dtype=real_dtype) |
| 430 | + |
| 431 | + if self.permute_l: |
| 432 | + # In this case, L is actually P @ L |
| 433 | + return Apply(self, inputs=[x], outputs=[L, U]) |
| 434 | + elif self.p_indices: |
| 435 | + p = tensor(shape=(x.type.shape[0],), dtype=p_dtype) |
| 436 | + return Apply(self, inputs=[x], outputs=[p, L, U]) |
| 437 | + else: |
| 438 | + P = tensor(shape=x.type.shape, dtype=p_dtype) |
| 439 | + return Apply(self, inputs=[x], outputs=[P, L, U]) |
| 440 | + |
| 441 | + def perform(self, node, inputs, outputs): |
| 442 | + [A] = inputs |
| 443 | + |
| 444 | + out = scipy.linalg.lu( |
| 445 | + A, |
| 446 | + permute_l=self.permute_l, |
| 447 | + overwrite_a=self.overwrite_a, |
| 448 | + check_finite=self.check_finite, |
| 449 | + p_indices=self.p_indices, |
| 450 | + ) |
| 451 | + |
| 452 | + outputs[0][0] = out[0] |
| 453 | + outputs[1][0] = out[1] |
| 454 | + |
| 455 | + if not self.permute_l: |
| 456 | + # In all cases except permute_l, there are three returns |
| 457 | + outputs[2][0] = out[2] |
| 458 | + |
| 459 | + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": |
| 460 | + if 0 in allowed_inplace_inputs: |
| 461 | + new_props = self._props_dict() # type: ignore |
| 462 | + new_props["overwrite_a"] = True |
| 463 | + return type(self)(**new_props) |
| 464 | + else: |
| 465 | + return self |
| 466 | + |
| 467 | + def L_op( |
| 468 | + self, |
| 469 | + inputs: Sequence[ptb.Variable], |
| 470 | + outputs: Sequence[ptb.Variable], |
| 471 | + output_grads: Sequence[ptb.Variable], |
| 472 | + ) -> list[ptb.Variable]: |
| 473 | + r""" |
|  474 | +        Derivation follows "Differentiation of Matrix Functionals Using Triangular Factorization", |
|  475 | +        F. R. de Hoog, R. S. Anderssen, and M. A. Lukas. |
| 476 | + """ |
| 477 | + [A] = inputs |
| 478 | + A = cast(TensorVariable, A) |
| 479 | + |
| 480 | + if self.permute_l: |
| 481 | + PL_bar, U_bar = output_grads |
| 482 | + |
| 483 | + # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient |
| 484 | + P, L, U = lu( # type: ignore |
| 485 | + A, permute_l=False, check_finite=self.check_finite, p_indices=False |
| 486 | + ) |
| 487 | + |
| 488 | + # Permutation matrix is orthogonal |
| 489 | + L_bar = ( |
| 490 | + P.T @ PL_bar |
| 491 | + if not isinstance(PL_bar.type, DisconnectedType) |
| 492 | + else pt.zeros_like(A) |
| 493 | + ) |
| 494 | + |
| 495 | + elif self.p_indices: |
| 496 | + p, L, U = outputs |
| 497 | + |
| 498 | + # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient |
| 499 | + P = pt.eye(A.shape[0])[p] |
| 500 | + _, L_bar, U_bar = output_grads |
| 501 | + else: |
| 502 | + P, L, U = outputs |
| 503 | + _, L_bar, U_bar = output_grads |
| 504 | + |
| 505 | + L_bar = ( |
| 506 | + L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A) |
| 507 | + ) |
| 508 | + U_bar = ( |
| 509 | + U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A) |
| 510 | + ) |
| 511 | + |
| 512 | + x1 = ptb.tril(L.T @ L_bar, k=-1) |
| 513 | + x2 = ptb.triu(U_bar @ U.T) |
| 514 | + |
| 515 | + L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True) |
| 516 | + A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T |
| 517 | + |
| 518 | + return [A_bar] |
| 519 | + |
| 520 | + |
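
For reference, the adjoint from de Hoog, Anderssen & Lukas implemented in `LU.L_op` above can be replayed in plain numpy and checked against a finite-difference derivative of a scalar functional of the factors. A minimal sketch; the seed, size, and probed entry are arbitrary, and `L_bar` is kept strictly lower triangular because the unit diagonal of `L` carries no gradient:

```python
import numpy as np
from scipy.linalg import lu as scipy_lu, solve_triangular

rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # diagonally dominant: stable pivoting
P, L, U = scipy_lu(A)  # A = P @ L @ U

L_bar = np.tril(rng.normal(size=(4, 4)), k=-1)  # cotangent for L (strictly lower)
U_bar = np.triu(rng.normal(size=(4, 4)))        # cotangent for U

# The rule implemented in LU.L_op
x1 = np.tril(L.T @ L_bar, k=-1)
x2 = np.triu(U_bar @ U.T)
L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T

# Finite-difference check of phi(A) = <L_bar, L> + <U_bar, U> w.r.t. A[2, 0]
def phi(M):
    _, Lm, Um = scipy_lu(M)
    return np.sum(L_bar * Lm) + np.sum(U_bar * Um)

eps = 1e-6
E = np.zeros_like(A)
E[2, 0] = eps
fd = (phi(A + E) - phi(A - E)) / (2 * eps)
np.testing.assert_allclose(fd, A_bar[2, 0], rtol=1e-4)
```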
| 521 | +def lu( |
| 522 | + a: TensorLike, permute_l=False, check_finite=True, p_indices=False |
| 523 | +) -> ( |
| 524 | + tuple[TensorVariable, TensorVariable, TensorVariable] |
| 525 | + | tuple[TensorVariable, TensorVariable] |
| 526 | +): |
| 527 | + """ |
|  528 | +    Factorize a matrix as the product of a permutation matrix, a unit lower triangular matrix, and an upper triangular matrix: |
| 529 | +
|
|  530 | +    .. math:: |
| 531 | +
|
| 532 | + A = P L U |
| 533 | +
|
|  534 | +    where P is a permutation matrix, L is lower triangular with unit diagonal, and U is upper triangular. |
| 535 | +
|
| 536 | + Parameters |
| 537 | + ---------- |
| 538 | + a: TensorLike |
| 539 | + Matrix to be factorized |
| 540 | + permute_l: bool |
| 541 | + If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will |
| 542 | + be returned in this case, and PL will not be lower triangular. |
| 543 | + check_finite: bool |
| 544 | + Whether to check that the input matrix contains only finite numbers. |
| 545 | + p_indices: bool |
| 546 | + If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix |
| 547 | + itself. |
| 548 | +
|
| 549 | + Returns |
| 550 | + ------- |
| 551 | + P: TensorVariable |
|  552 | +        Permutation matrix, or an array of integer indices encoding the permutation matrix if p_indices is True. Not returned if permute_l is True. |
| 553 | + L: TensorVariable |
| 554 | + Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True. |
| 555 | + U: TensorVariable |
| 556 | + Upper triangular matrix |
| 557 | + """ |
| 558 | + return cast( |
| 559 | + tuple[TensorVariable, TensorVariable, TensorVariable] |
| 560 | + | tuple[TensorVariable, TensorVariable], |
| 561 | + LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a), |
| 562 | + ) |
| 563 | + |
| 564 | + |
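
A usage sketch for the new `lu` helper, assuming this diff lives in `pytensor.tensor.slinalg` alongside the neighboring solve Ops (any re-export location is not shown in the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu

A = pt.dmatrix("A")

P, L, U = lu(A)                    # default: A = P @ L @ U
PL, U2 = lu(A, permute_l=True)     # permutation folded into L: A = PL @ U2
p, L3, U3 = lu(A, p_indices=True)  # integer indices instead of a dense P

f = pytensor.function([A], [P, L, U])
A_val = np.random.default_rng(2).normal(size=(3, 3))
P_val, L_val, U_val = f(A_val)
np.testing.assert_allclose(P_val @ L_val @ U_val, A_val, atol=1e-12)
```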
399 | 565 | class SolveTriangular(SolveBase): |
400 | 566 | """Solve a system of linear equations.""" |
401 | 567 |
|
@@ -513,13 +679,13 @@ class Solve(SolveBase): |
513 | 679 | def __init__(self, *, assume_a="gen", transposed=False, **kwargs): |
514 | 680 | if assume_a not in ("gen", "sym", "her", "pos"): |
515 | 681 | raise ValueError(f"{assume_a} is not a recognized matrix structure") |
| 682 | + |
516 | 683 | super().__init__(**kwargs) |
517 | 684 | self.assume_a = assume_a |
518 | 685 | self.transposed = transposed |
519 | 686 |
|
520 | 687 | def perform(self, node, inputs, outputs): |
521 | 688 | a, b = inputs |
522 | | - |
523 | 689 | outputs[0][0] = scipy.linalg.solve( |
524 | 690 | a=a, |
525 | 691 | b=b, |
@@ -1083,7 +1249,7 @@ def solve_discrete_are( |
1083 | 1249 | ) |
1084 | 1250 |
|
1085 | 1251 |
|
1086 | | -def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype: |
| 1252 | +def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype: |
1087 | 1253 | return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors]) |
1088 | 1254 |
|
1089 | 1255 |
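
`_largest_common_dtype` left-folds `np.promote_types` over the operand dtypes, yielding the smallest dtype that can represent every input. For example:

```python
import numpy as np
from functools import reduce

# float32 with int64 promotes to float64; float64 with complex64 promotes to complex128
reduce(np.promote_types, ["float32", "int64", "complex64"])  # dtype('complex128')
```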
|
|