|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 |
|
8 | | -import pytensor.tensor as pt |
9 | 8 | from pytensor import scalar as ps |
10 | 9 | from pytensor.compile.builders import OpFromGraph |
11 | 10 | from pytensor.gradient import DisconnectedType |
12 | 11 | from pytensor.graph.basic import Apply |
13 | 12 | from pytensor.graph.op import Op |
14 | | -from pytensor.ifelse import ifelse |
15 | 13 | from pytensor.npy_2_compat import normalize_axis_tuple |
16 | | -from pytensor.raise_op import Assert |
17 | 14 | from pytensor.tensor import TensorLike |
18 | 15 | from pytensor.tensor import basic as ptb |
19 | 16 | from pytensor.tensor import math as ptm |
@@ -468,173 +465,6 @@ def eigh(a, UPLO="L"): |
468 | 465 | return Eigh(UPLO)(a) |
469 | 466 |
|
470 | 467 |
|
471 | | -class QRFull(Op): |
472 | | - """ |
473 | | - Full QR Decomposition. |
474 | | -
|
475 | | - Computes the QR decomposition of a matrix. |
476 | | - Factor the matrix a as qr, where q is orthonormal |
477 | | - and r is upper-triangular. |
478 | | -
|
479 | | - """ |
480 | | - |
481 | | - __props__ = ("mode",) |
482 | | - |
483 | | - def __init__(self, mode): |
484 | | - self.mode = mode |
485 | | - |
486 | | - def make_node(self, x): |
487 | | - x = as_tensor_variable(x) |
488 | | - |
489 | | - assert x.ndim == 2, "The input of qr function should be a matrix." |
490 | | - |
491 | | - in_dtype = x.type.numpy_dtype |
492 | | - out_dtype = np.dtype(f"f{in_dtype.itemsize}") |
493 | | - |
494 | | - q = matrix(dtype=out_dtype) |
495 | | - |
496 | | - if self.mode != "raw": |
497 | | - r = matrix(dtype=out_dtype) |
498 | | - else: |
499 | | - r = vector(dtype=out_dtype) |
500 | | - |
501 | | - if self.mode != "r": |
502 | | - q = matrix(dtype=out_dtype) |
503 | | - outputs = [q, r] |
504 | | - else: |
505 | | - outputs = [r] |
506 | | - |
507 | | - return Apply(self, [x], outputs) |
508 | | - |
509 | | - def perform(self, node, inputs, outputs): |
510 | | - (x,) = inputs |
511 | | - assert x.ndim == 2, "The input of qr function should be a matrix." |
512 | | - res = np.linalg.qr(x, self.mode) |
513 | | - if self.mode != "r": |
514 | | - outputs[0][0], outputs[1][0] = res |
515 | | - else: |
516 | | - outputs[0][0] = res |
517 | | - |
518 | | - def L_op(self, inputs, outputs, output_grads): |
519 | | - """ |
520 | | - Reverse-mode gradient of the QR function. |
521 | | -
|
522 | | - References |
523 | | - ---------- |
524 | | - .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/ |
525 | | - .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2 |
526 | | - """ |
527 | | - |
528 | | - from pytensor.tensor.slinalg import solve_triangular |
529 | | - |
530 | | - (A,) = (cast(ptb.TensorVariable, x) for x in inputs) |
531 | | - m, n = A.shape |
532 | | - |
533 | | - def _H(x: ptb.TensorVariable): |
534 | | - return x.conj().mT |
535 | | - |
536 | | - def _copyltu(x: ptb.TensorVariable): |
537 | | - return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1)) |
538 | | - |
539 | | - if self.mode == "raw": |
540 | | - raise NotImplementedError("Gradient of qr not implemented for mode=raw") |
541 | | - |
542 | | - elif self.mode == "r": |
543 | | - # We need all the components of the QR to compute the gradient of A even if we only |
544 | | - # use the upper triangular component in the cost function. |
545 | | - Q, R = qr(A, mode="reduced") |
546 | | - dQ = Q.zeros_like() |
547 | | - dR = cast(ptb.TensorVariable, output_grads[0]) |
548 | | - |
549 | | - else: |
550 | | - Q, R = (cast(ptb.TensorVariable, x) for x in outputs) |
551 | | - if self.mode == "complete": |
552 | | - qr_assert_op = Assert( |
553 | | - "Gradient of qr not implemented for m x n matrices with m > n and mode=complete" |
554 | | - ) |
555 | | - R = qr_assert_op(R, ptm.le(m, n)) |
556 | | - |
557 | | - new_output_grads = [] |
558 | | - is_disconnected = [ |
559 | | - isinstance(x.type, DisconnectedType) for x in output_grads |
560 | | - ] |
561 | | - if all(is_disconnected): |
562 | | - # This should never be reached by Pytensor |
563 | | - return [DisconnectedType()()] # pragma: no cover |
564 | | - |
565 | | - for disconnected, output_grad, output in zip( |
566 | | - is_disconnected, output_grads, [Q, R], strict=True |
567 | | - ): |
568 | | - if disconnected: |
569 | | - new_output_grads.append(output.zeros_like()) |
570 | | - else: |
571 | | - new_output_grads.append(output_grad) |
572 | | - |
573 | | - (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads) |
574 | | - |
575 | | - # gradient expression when m >= n |
576 | | - M = R @ _H(dR) - _H(dQ) @ Q |
577 | | - K = dQ + Q @ _copyltu(M) |
578 | | - A_bar_m_ge_n = _H(solve_triangular(R, _H(K))) |
579 | | - |
580 | | - # gradient expression when m < n |
581 | | - Y = A[:, m:] |
582 | | - U = R[:, :m] |
583 | | - dU, dV = dR[:, :m], dR[:, m:] |
584 | | - dQ_Yt_dV = dQ + Y @ _H(dV) |
585 | | - M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q |
586 | | - X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M)))) |
587 | | - Y_bar = Q @ dV |
588 | | - A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1) |
589 | | - |
590 | | - return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] |
591 | | - |
592 | | - |
593 | | -def qr(a, mode="reduced"): |
594 | | - """ |
595 | | - Computes the QR decomposition of a matrix. |
596 | | - Factor the matrix a as qr, where q |
597 | | - is orthonormal and r is upper-triangular. |
598 | | -
|
599 | | - Parameters |
600 | | - ---------- |
601 | | - a : array_like, shape (M, N) |
602 | | - Matrix to be factored. |
603 | | -
|
604 | | - mode : {'reduced', 'complete', 'r', 'raw'}, optional |
605 | | - If K = min(M, N), then |
606 | | -
|
607 | | - 'reduced' |
608 | | - returns q, r with dimensions (M, K), (K, N) |
609 | | -
|
610 | | - 'complete' |
611 | | - returns q, r with dimensions (M, M), (M, N) |
612 | | -
|
613 | | - 'r' |
614 | | - returns r only with dimensions (K, N) |
615 | | -
|
616 | | - 'raw' |
617 | | - returns h, tau with dimensions (N, M), (K,) |
618 | | -
|
619 | | - Note that array h returned in 'raw' mode is |
620 | | - transposed for calling Fortran. |
621 | | -
|
622 | | - Default mode is 'reduced' |
623 | | -
|
624 | | - Returns |
625 | | - ------- |
626 | | - q : matrix of float or complex, optional |
627 | | - A matrix with orthonormal columns. When mode = 'complete' the |
628 | | - result is an orthogonal/unitary matrix depending on whether or |
629 | | - not a is real/complex. The determinant may be either +/- 1 in |
630 | | - that case. |
631 | | - r : matrix of float or complex, optional |
632 | | - The upper-triangular matrix. |
633 | | -
|
634 | | - """ |
635 | | - return QRFull(mode)(a) |
636 | | - |
637 | | - |
638 | 468 | class SVD(Op): |
639 | 469 | """ |
640 | 470 | Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V |
@@ -1291,7 +1121,6 @@ def kron(a, b): |
1291 | 1121 | "det", |
1292 | 1122 | "eig", |
1293 | 1123 | "eigh", |
1294 | | - "qr", |
1295 | 1124 | "svd", |
1296 | 1125 | "lstsq", |
1297 | 1126 | "matrix_power", |
|
0 commit comments