|
| 1 | +#:include "common.fypp" |
| 2 | +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX, REAL_INIT)) |
| 3 | +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX, CMPLX_INIT)) |
| 4 | +#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES |
| 5 | +submodule (stdlib_linalg) stdlib_linalg_matrix_functions |
| 6 | + use stdlib_constants |
| 7 | + use stdlib_linalg_constants |
| 8 | + use stdlib_linalg_blas, only: gemm |
| 9 | + use stdlib_linalg_lapack, only: gesv, lacpy |
| 10 | + use stdlib_linalg_lapack_aux, only: handle_gesv_info |
| 11 | + use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, & |
| 12 | + LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR |
| 13 | + implicit none(type, external) |
| 14 | + |
| 15 | + character(len=*), parameter :: this = "matrix_exponential" |
| 16 | + |
| 17 | +contains |
| 18 | + |
| 19 | + #:for k,t,s, i in RC_KINDS_TYPES |
| 20 | + module function stdlib_linalg_${i}$_expm_fun(A, order) result(E) |
| 21 | + !> Input matrix A(n, n). |
| 22 | + ${t}$, intent(in) :: A(:, :) |
| 23 | + !> [optional] Order of the Pade approximation. |
| 24 | + integer(ilp), optional, intent(in) :: order |
| 25 | + !> Exponential of the input matrix E = exp(A). |
| 26 | + ${t}$, allocatable :: E(:, :) |
| 27 | + |
| 28 | + E = A |
| 29 | + call stdlib_linalg_${i}$_expm_inplace(E, order) |
| 30 | + end function stdlib_linalg_${i}$_expm_fun |
| 31 | + |
| 32 | + module subroutine stdlib_linalg_${i}$_expm(A, E, order, err) |
| 33 | + !> Input matrix A(n, n). |
| 34 | + ${t}$, intent(in) :: A(:, :) |
| 35 | + !> Exponential of the input matrix E = exp(A). |
| 36 | + ${t}$, intent(out) :: E(:, :) |
| 37 | + !> [optional] Order of the Pade approximation. |
| 38 | + integer(ilp), optional, intent(in) :: order |
| 39 | + !> [optional] State return flag. |
| 40 | + type(linalg_state_type), optional, intent(out) :: err |
| 41 | + |
| 42 | + type(linalg_state_type) :: err0 |
| 43 | + integer(ilp) :: lda, n, lde, ne |
| 44 | + |
| 45 | + ! Check E sizes |
| 46 | + lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp) |
| 47 | + lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp) |
| 48 | + |
| 49 | + if (lda<1 .or. n<1 .or. lda/=n .or. lde/=n .or. ne/=n) then |
| 50 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR, & |
| 51 | + 'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', & |
| 52 | + ' E must be square (lde=', lde, ', ne=', ne, ')') |
| 53 | + else |
| 54 | + call lacpy("n", n, n, A, n, E, n) ! E = A |
| 55 | + call stdlib_linalg_${i}$_expm_inplace(E, order, err0) |
| 56 | + endif |
| 57 | + |
| 58 | + ! Process output and return |
| 59 | + call linalg_error_handling(err0,err) |
| 60 | + |
| 61 | + return |
| 62 | + end subroutine stdlib_linalg_${i}$_expm |
| 63 | + |
| 64 | + module subroutine stdlib_linalg_${i}$_expm_inplace(A, order, err) |
| 65 | + !> Input matrix A(n, n) / Output matrix exponential. |
| 66 | + ${t}$, intent(inout) :: A(:, :) |
| 67 | + !> [optional] Order of the Pade approximation. |
| 68 | + integer(ilp), optional, intent(in) :: order |
| 69 | + !> [optional] State return flag. |
| 70 | + type(linalg_state_type), optional, intent(out) :: err |
| 71 | + |
| 72 | + ! Internal variables. |
| 73 | + ${t}$ :: A2(size(A, 1), size(A, 2)), Q(size(A, 1), size(A, 2)) |
| 74 | + ${t}$ :: X(size(A, 1), size(A, 2)), X_tmp(size(A, 1), size(A, 2)) |
| 75 | + real(${k}$) :: a_norm, c |
| 76 | + integer(ilp) :: m, n, ee, k, s, order_, i, j |
| 77 | + logical(lk) :: p |
| 78 | + type(linalg_state_type) :: err0 |
| 79 | + |
| 80 | + ! Deal with optional args. |
| 81 | + order_ = 10 ; if (present(order)) order_ = order |
| 82 | + |
| 83 | + ! Problem's dimension. |
| 84 | + m = size(A, dim=1, kind=ilp) ; n = size(A, dim=2, kind=ilp) |
| 85 | + |
| 86 | + if (m /= n) then |
| 87 | + err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n]) |
| 88 | + else if (order_ < 0) then |
| 89 | + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation & |
| 90 | + needs to be positive, order=', order_) |
| 91 | + else |
| 92 | + ! Compute the L-infinity norm. |
| 93 | + a_norm = mnorm(A, "inf") |
| 94 | + |
| 95 | + ! Determine scaling factor for the matrix. |
| 96 | + ee = int(log(a_norm) / log2_${k}$, kind=ilp) + 1 |
| 97 | + s = max(0, ee+1) |
| 98 | + |
| 99 | + ! Scale the input matrix & initialize polynomial. |
| 100 | + A2 = A/2.0_${k}$**s |
| 101 | + call lacpy("n", n, n, A2, n, X, n) ! X = A2 |
| 102 | + |
| 103 | + ! First step of the Pade approximation. |
| 104 | + c = 0.5_${k}$ |
| 105 | + do concurrent(i=1:n, j=1:n) |
| 106 | + A(i, j) = merge(1.0_${k}$ + c*A2(i, j), c*A2(i, j), i == j) |
| 107 | + Q(i, j) = merge(1.0_${k}$ - c*A2(i, j), -c*A2(i, j), i == j) |
| 108 | + enddo |
| 109 | + |
| 110 | + ! Iteratively compute the Pade approximation. |
| 111 | + p = .true. |
| 112 | + do k = 2, order_ |
| 113 | + c = c * (order_ - k + 1) / (k * (2*order_ - k + 1)) |
| 114 | + call lacpy("n", n, n, X, n, X_tmp, n) ! X_tmp = X |
| 115 | + call gemm("N", "N", n, n, n, one_${s}$, A2, n, X_tmp, n, zero_${s}$, X, n) |
| 116 | + do concurrent(i=1:n, j=1:n) |
| 117 | + A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X |
| 118 | + Q(i, j) = merge(Q(i, j) + c*X(i, j), Q(i, j) - c*X(i, j), p) |
| 119 | + enddo |
| 120 | + p = .not. p |
| 121 | + enddo |
| 122 | + |
| 123 | + block |
| 124 | + integer(ilp) :: ipiv(n), info |
| 125 | + call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E |
| 126 | + call handle_gesv_info(this, info, n, n, n, err0) |
| 127 | + end block |
| 128 | + |
| 129 | + ! Matrix squaring. |
| 130 | + do k = 1, s |
| 131 | + call lacpy("n", n, n, A, n, X, n) ! X = A |
| 132 | + call gemm("N", "N", n, n, n, one_${s}$, X, n, X, n, zero_${s}$, A, n) |
| 133 | + enddo |
| 134 | + endif |
| 135 | + |
| 136 | + call linalg_error_handling(err0, err) |
| 137 | + |
| 138 | + return |
| 139 | + end subroutine stdlib_linalg_${i}$_expm_inplace |
| 140 | + #:endfor |
| 141 | + |
| 142 | +end submodule stdlib_linalg_matrix_functions |
0 commit comments