|
10 | 10 |
|
11 | 11 | import pytensor |
12 | 12 | import pytensor.tensor as pt |
| 13 | +from pytensor.gradient import DisconnectedType |
13 | 14 | from pytensor.graph.basic import Apply |
14 | 15 | from pytensor.graph.op import Op |
15 | 16 | from pytensor.tensor import TensorLike, as_tensor_variable |
@@ -303,6 +304,7 @@ def L_op(self, inputs, outputs, output_gradients): |
303 | 304 | } |
304 | 305 | ) |
305 | 306 | b_bar = trans_solve_op(A.T, c_bar) |
| 307 | + |
306 | 308 | # force outer product if vector second input |
307 | 309 | A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T) |
308 | 310 |
|
@@ -381,6 +383,188 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None): |
381 | 383 | )(A, b) |
382 | 384 |
|
383 | 385 |
|
class LU(Op):
    """Decompose a matrix into lower and upper triangular matrices.

    The computation is delegated to ``scipy.linalg.lu``. The outputs of a node
    depend on the flags the Op was built with:

    * default: ``(P, L, U)`` where ``P`` has the same static shape as the input
    * ``p_indices=True``: ``(p, L, U)`` where ``p`` is an int32 vector
    * ``permute_l=True``: ``(PL, U)``
    """

    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")

    def __init__(
        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
    ):
        """
        Parameters
        ----------
        permute_l: bool
            If True, the node has two outputs, ``PL`` and ``U``.
        overwrite_a: bool
            Forwarded to ``scipy.linalg.lu``. When True the Op declares that
            output 0 destroys input 0 via ``destroy_map``.
        check_finite: bool
            Forwarded to ``scipy.linalg.lu``.
        p_indices: bool
            If True, the first output is an int32 index vector instead of a
            permutation matrix.

        Raises
        ------
        ValueError
            If both ``permute_l`` and ``p_indices`` are True.
        """
        if permute_l and p_indices:
            raise ValueError("Only one of permute_l and p_indices can be True")
        self.permute_l = permute_l
        self.check_finite = check_finite
        self.p_indices = p_indices
        self.overwrite_a = overwrite_a

        # permute_l and p_indices are mutually exclusive (checked above), so each
        # flag selects its own output signature.
        if self.permute_l:
            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
        elif self.p_indices:
            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
        else:
            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"

        if self.overwrite_a:
            self.destroy_map = {0: [0]}

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shapes, all derived from the first input dimension."""
        n = shapes[0][0]
        if self.permute_l:
            return [(n, n), (n, n)]
        elif self.p_indices:
            return [(n,), (n, n), (n, n)]
        else:
            return [(n, n), (n, n), (n, n)]

    def make_node(self, x):
        """Build the Apply node for a 2-D input ``x``.

        Raises
        ------
        TypeError
            If ``x`` is not 2-dimensional.
        """
        x = as_tensor_variable(x)
        if x.type.ndim != 2:
            raise TypeError(
                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
            )

        # The permutation matrix is real-valued: single precision when the input's
        # dtype char is "f" or "F", double precision otherwise. The index vector
        # is always int32.
        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)

        L = tensor(shape=x.type.shape, dtype=x.type.dtype)
        U = tensor(shape=x.type.shape, dtype=x.type.dtype)

        if self.permute_l:
            # In this case, L is actually P @ L
            return Apply(self, inputs=[x], outputs=[L, U])
        if self.p_indices:
            p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
            return Apply(self, inputs=[x], outputs=[p_indices, L, U])

        P = tensor(shape=x.type.shape, dtype=p_dtype)
        return Apply(self, inputs=[x], outputs=[P, L, U])

    def perform(self, node, inputs, outputs):
        """Run ``scipy.linalg.lu`` on the input and store its results in ``outputs``."""
        [A] = inputs

        out = scipy_linalg.lu(
            A,
            permute_l=self.permute_l,
            overwrite_a=self.overwrite_a,
            check_finite=self.check_finite,
            p_indices=self.p_indices,
        )

        outputs[0][0] = out[0]
        outputs[1][0] = out[1]

        if not self.permute_l:
            # In all cases except permute_l, there are three returns
            outputs[2][0] = out[2]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a copy of this Op with ``overwrite_a=True`` if input 0 may be destroyed.

        Returns ``self`` unchanged when input 0 is not in ``allowed_inplace_inputs``.
        """
        if 0 in allowed_inplace_inputs:
            new_props = self._props_dict()  # type: ignore
            new_props["overwrite_a"] = True
            return type(self)(**new_props)
        else:
            return self

    def L_op(
        self,
        inputs: Sequence[ptb.Variable],
        outputs: Sequence[ptb.Variable],
        output_grads: Sequence[ptb.Variable],
    ) -> list[ptb.Variable]:
        r"""
        Compute the gradient of the decomposition with respect to its input.

        Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
        F. R. De Hoog, R.S. Anderssen, M. A. Lukas

        The permutation is treated as a constant; only the cotangents of ``L`` (or ``PL``) and ``U``
        contribute. Disconnected cotangents are replaced by zeros.
        """
        [A] = inputs
        A = cast(TensorVariable, A)

        if self.permute_l:
            # Here the first cotangent is with respect to PL = P @ L, not L; it is pulled back to L below.
            L_bar, U_bar = output_grads

            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
            # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
            P_or_indices, L, U = lu(  # type: ignore
                A, permute_l=False, check_finite=self.check_finite, p_indices=False
            )

        else:
            # In both other cases, there are 3 outputs. The first output will either be the permutation matrix
            # itself, or indices that can be used to reconstruct the permutation matrix.
            P_or_indices, L, U = outputs
            _, L_bar, U_bar = output_grads

        if isinstance(L_bar.type, DisconnectedType):
            L_bar = pt.zeros_like(A)
        elif self.permute_l:
            # With P held constant, <PL_bar, P @ dL> = <P.T @ PL_bar, dL>, so the cotangent of L is
            # P.T @ PL_bar. Using PL_bar directly is only right when it is invariant to row permutation.
            L_bar = P_or_indices.T @ L_bar

        if isinstance(U_bar.type, DisconnectedType):
            U_bar = pt.zeros_like(A)

        x1 = ptb.tril(L.T @ L_bar, k=-1)
        x2 = ptb.triu(U_bar @ U.T)

        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)

        # Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation
        B_bar = solve_triangular(U, LT_inv_x.T, lower=False).T

        # Map the cotangent of B back to A = P @ B, either with the matrix or with the row indices.
        if not self.p_indices:
            A_bar = P_or_indices @ B_bar
        else:
            A_bar = B_bar[P_or_indices]

        return [A_bar]
| 520 | + |
| 521 | + |
def lu(
    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
) -> (
    tuple[TensorVariable, TensorVariable, TensorVariable]
    | tuple[TensorVariable, TensorVariable]
):
    """
    Compute the pivoted LU factorization of a matrix.

    The factorization has the form

    .. math::

        A = P L U

    with P a permutation matrix, L lower triangular with ones on the diagonal, and U upper triangular.
    The core ``LU`` Op is wrapped in ``Blockwise`` before being applied to ``a``.

    Parameters
    ----------
    a: TensorLike
        Matrix to be factorized
    permute_l: bool
        If True, the permutation is folded into L: two values, PL and U, are returned, and PL is in general
        not lower triangular.
    check_finite: bool
        Whether to check that the input matrix contains only finite numbers.
    p_indices: bool
        If True, the permutation is returned as a vector of integer indices rather than as a permutation matrix.
        Cannot be combined with ``permute_l``.

    Returns
    -------
    P: TensorVariable
        Permutation matrix, or array of integer indices for the permutation. Not returned if permute_l is True.
    L: TensorVariable
        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
    U: TensorVariable
        Upper triangular matrix
    """
    core_op = LU(permute_l=permute_l, p_indices=p_indices, check_finite=check_finite)
    factors = Blockwise(core_op)(a)
    return cast(
        tuple[TensorVariable, TensorVariable, TensorVariable]
        | tuple[TensorVariable, TensorVariable],
        factors,
    )
| 566 | + |
| 567 | + |
384 | 568 | class SolveTriangular(SolveBase): |
385 | 569 | """Solve a system of linear equations.""" |
386 | 570 |
|
|
0 commit comments