|
| 1 | +#:include "common.fypp" |
| 2 | +#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES |
| 3 | +submodule (stdlib_linalg) stdlib_linalg_matrix_functions |
| 4 | + use stdlib_constants |
| 5 | + use stdlib_linalg_constants |
| 6 | + use stdlib_linalg_blas, only: gemm |
| 7 | + use stdlib_linalg_lapack, only: gesv |
| 8 | + use stdlib_linalg_lapack_aux, only: handle_gesv_info |
| 9 | + use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, & |
| 10 | + LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR |
| 11 | + implicit none |
| 12 | + |
| 13 | + character(len=*), parameter :: this = "matrix_exponential" |
| 14 | + |
| 15 | +contains |
| 16 | + |
| 17 | + #:for rk,rt,ri in RC_KINDS_TYPES |
| 18 | + module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E) |
| 19 | + !> Input matrix A(n, n). |
| 20 | + ${rt}$, intent(in) :: A(:, :) |
| 21 | + !> [optional] Order of the Pade approximation. |
| 22 | + integer(ilp), optional, intent(in) :: order |
| 23 | + !> Exponential of the input matrix E = exp(A). |
| 24 | + ${rt}$, allocatable :: E(:, :) |
| 25 | + |
| 26 | + E = A ; call stdlib_linalg_${ri}$_expm_inplace(E, order) |
| 27 | + end function |
| 28 | + |
| 29 | + module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err) |
| 30 | + !> Input matrix A(n, n). |
| 31 | + ${rt}$, intent(in) :: A(:, :) |
| 32 | + !> [optional] Order of the Pade approximation. |
| 33 | + integer(ilp), optional, intent(in) :: order |
| 34 | + !> [optional] State return flag. |
| 35 | + type(linalg_state_type), optional, intent(out) :: err |
| 36 | + !> Exponential of the input matrix E = exp(A). |
| 37 | + ${rt}$, intent(out) :: E(:, :) |
| 38 | + |
| 39 | + type(linalg_state_type) :: err0 |
| 40 | + integer(ilp) :: lda, n, lde, ne |
| 41 | + |
| 42 | + ! Check E sizes |
| 43 | + lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp) |
| 44 | + lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp) |
| 45 | + |
| 46 | + if (lda<1 .or. n<1 .or. lda/=n .or. lde/=n .or. ne/=n) then |
| 47 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR, & |
| 48 | + 'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', & |
| 49 | + ' E must be square (lde=', lde, ', ne=', ne, ')') |
| 50 | + else |
| 51 | + E(:n, :n) = A(:n, :n) ; call stdlib_linalg_${ri}$_expm_inplace(E, order, err0) |
| 52 | + endif |
| 53 | + |
| 54 | + ! Process output and return |
| 55 | + call linalg_error_handling(err0,err) |
| 56 | + |
| 57 | + return |
| 58 | + end subroutine stdlib_linalg_${ri}$_expm |
| 59 | + |
| 60 | + module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err) |
| 61 | + !> Input matrix A(n, n) / Output matrix exponential. |
| 62 | + ${rt}$, intent(inout) :: A(:, :) |
| 63 | + !> [optional] Order of the Pade approximation. |
| 64 | + integer(ilp), optional, intent(in) :: order |
| 65 | + !> [optional] State return flag. |
| 66 | + type(linalg_state_type), optional, intent(out) :: err |
| 67 | + |
| 68 | + ! Internal variables. |
| 69 | + ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :) |
| 70 | + real(${rk}$) :: a_norm, c |
| 71 | + integer(ilp) :: m, n, ee, k, s, order_, i, j |
| 72 | + logical(lk) :: p |
| 73 | + type(linalg_state_type) :: err0 |
| 74 | + |
| 75 | + ! Deal with optional args. |
| 76 | + order_ = 10 ; if (present(order)) order_ = order |
| 77 | + |
| 78 | + ! Problem's dimension. |
| 79 | + m = size(A, dim=1, kind=ilp) ; n = size(A, dim=2, kind=ilp) |
| 80 | + |
| 81 | + if (m /= n) then |
| 82 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n]) |
| 83 | + else if (order_ < 0) then |
| 84 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation & |
| 85 | + needs to be positive, order=', order_) |
| 86 | + else |
| 87 | + ! Compute the L-infinity norm. |
| 88 | + a_norm = mnorm(A, "inf") |
| 89 | + |
| 90 | + ! Determine scaling factor for the matrix. |
| 91 | + ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1 |
| 92 | + s = max(0, ee+1) |
| 93 | + |
| 94 | + ! Scale the input matrix & initialize polynomial. |
| 95 | + A2 = A/2.0_${rk}$**s ; X = A2 |
| 96 | + |
| 97 | + ! First step of the Pade approximation. |
| 98 | + c = 0.5_${rk}$ |
| 99 | + allocate (Q, source=A2) ; A = A2 |
| 100 | + do concurrent(i=1:n, j=1:n) |
| 101 | + A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j) |
| 102 | + Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j) |
| 103 | + enddo |
| 104 | + |
| 105 | + ! Iteratively compute the Pade approximation. |
| 106 | + block |
| 107 | + ${rt}$, allocatable :: X_tmp(:, :) |
| 108 | + p = .true. |
| 109 | + do k = 2, order_ |
| 110 | + c = c * (order_ - k + 1) / (k * (2*order_ - k + 1)) |
| 111 | + X_tmp = X |
| 112 | + #:if rt.startswith('complex') |
| 113 | + call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n) |
| 114 | + #:else |
| 115 | + call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n) |
| 116 | + #:endif |
| 117 | + do concurrent(i=1:n, j=1:n) |
| 118 | + A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X |
| 119 | + enddo |
| 120 | + if (p) then |
| 121 | + do concurrent(i=1:n, j=1:n) |
| 122 | + Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X |
| 123 | + enddo |
| 124 | + else |
| 125 | + do concurrent(i=1:n, j=1:n) |
| 126 | + Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X |
| 127 | + enddo |
| 128 | + endif |
| 129 | + p = .not. p |
| 130 | + enddo |
| 131 | + end block |
| 132 | + |
| 133 | + block |
| 134 | + integer(ilp) :: ipiv(n), info |
| 135 | + call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E |
| 136 | + call handle_gesv_info(this, info, n, n, n, err0) |
| 137 | + end block |
| 138 | + |
| 139 | + ! Matrix squaring. |
| 140 | + block |
| 141 | + ${rt}$, allocatable :: E_tmp(:, :) |
| 142 | + do k = 1, s |
| 143 | + E_tmp = A |
| 144 | + #:if rt.startswith('complex') |
| 145 | + call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n) |
| 146 | + #:else |
| 147 | + call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n) |
| 148 | + #:endif |
| 149 | + enddo |
| 150 | + end block |
| 151 | + endif |
| 152 | + |
| 153 | + call linalg_error_handling(err0, err) |
| 154 | + |
| 155 | + return |
| 156 | + end subroutine stdlib_linalg_${ri}$_expm_inplace |
| 157 | + #:endfor |
| 158 | + |
| 159 | +end submodule stdlib_linalg_matrix_functions |
0 commit comments