@@ -15,22 +15,62 @@ submodule (stdlib_linalg) stdlib_linalg_matrix_functions
1515contains
1616
1717 #:for rk,rt,ri in RC_KINDS_TYPES
18- module function stdlib_expm_${ri}$(A, order, err) result(E)
18+ module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
19+ !> Input matrix A(n, n).
20+ ${rt}$, intent(in) :: A(:, :)
21+ !> [optional] Order of the Pade approximation.
22+ integer(ilp), optional, intent(in) :: order
23+ !> Exponential of the input matrix E = exp(A).
24+ ${rt}$, allocatable :: E(:, :)
25+
26+ E = A ; call stdlib_linalg_${ri}$_expm_inplace(E, order)
27+ end function
28+
29+ module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
1930 !> Input matrix A(n, n).
2031 ${rt}$, intent(in) :: A(:, :)
2132 !> [optional] Order of the Pade approximation.
2233 integer(ilp), optional, intent(in) :: order
2334 !> [optional] State return flag.
2435 type(linalg_state_type), optional, intent(out) :: err
2536 !> Exponential of the input matrix E = exp(A).
26- ${rt}$, allocatable :: E(:, :)
37+ ${rt}$, intent(out) :: E(:, :)
38+
39+ type(linalg_state_type) :: err0
40+ integer(ilp) :: lda, n, lde, ne
41+
42+ ! Check E sizes
43+ lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp)
44+ lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp)
45+
46+ if (lda<1 .or. n<1 .or. lda<n .or. lde<n .or. ne<n) then
47+ err0 = linalg_state_type(this,LINALG_VALUE_ERROR, &
48+ 'invalid matrix sizes: A=',[lda,n], &
49+ ' E=',[lde,ne])
50+ else
51+ E(:n, :n) = A(:n, :n) ; call stdlib_linalg_${ri}$_expm_inplace(E, order, err)
52+ endif
53+
54+ ! Process output and return
55+ call linalg_error_handling(err0,err)
56+
57+ return
58+ end subroutine stdlib_linalg_${ri}$_expm
59+
60+ module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
61+ !> Input matrix A(n, n) / Output matrix exponential.
62+ ${rt}$, intent(inout) :: A(:, :)
63+ !> [optional] Order of the Pade approximation.
64+ integer(ilp), optional, intent(in) :: order
65+ !> [optional] State return flag.
66+ type(linalg_state_type), optional, intent(out) :: err
2767
2868 ! Internal variables.
29- ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
30- real(${rk}$) :: a_norm, c
31- integer(ilp) :: m, n, ee, k, s, order_, i, j
32- logical(lk) :: p
33- type(linalg_state_type) :: err0
69+ ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
70+ real(${rk}$) :: a_norm, c
71+ integer(ilp) :: m, n, ee, k, s, order_, i, j
72+ logical(lk) :: p
73+ type(linalg_state_type) :: err0
3474
3575 ! Deal with optional args.
3676 order_ = 10 ; if (present(order)) order_ = order
@@ -40,82 +80,80 @@ contains
4080
4181 if (m /= n) then
4282 err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
43- call linalg_error_handling(err0, err)
44- return
4583 else if (order_ < 0) then
4684 err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
4785 needs to be positive, order=', order_)
48- call linalg_error_handling(err0, err)
49- return
50- endif
86+ else
87+ ! Compute the L-infinity norm.
88+ a_norm = mnorm(A, "inf")
5189
52- ! Compute the L-infinity norm.
53- a_norm = mnorm(A, "inf")
54-
55- ! Determine scaling factor for the matrix.
56- ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
57- s = max(0, ee+1)
58-
59- ! Scale the input matrix & initialize polynomial.
60- A2 = A/2.0_${rk}$**s ; X = A2
61-
62- ! First step of the Pade approximation.
63- c = 0.5_${rk}$
64- allocate (E, source=A2) ; allocate (Q, source=A2)
65- do concurrent(i=1:n, j=1:n)
66- E(i, j) = merge(1.0_${rk}$ + c*E(i, j), c*E(i, j), i == j)
67- Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
68- enddo
69-
70- ! Iteratively compute the Pade approximation.
71- block
72- ${rt}$ :: X_tmp(n, n)
73- p = .true.
74- do k = 2, order_
75- c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
76- X_tmp = X
77- #:if rt.startswith('complex')
78- call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
79- #:else
80- call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
81- #:endif
82- do concurrent(i=1:n, j=1:n)
83- E(i, j) = E(i, j) + c*X(i, j) ! E = E + c*X
84- enddo
85- if (p) then
86- do concurrent(i=1:n, j=1:n)
87- Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
88- enddo
89- else
90+ ! Determine scaling factor for the matrix.
91+ ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
92+ s = max(0, ee+1)
93+
94+ ! Scale the input matrix & initialize polynomial.
95+ A2 = A/2.0_${rk}$**s ; X = A2
96+
97+ ! First step of the Pade approximation.
98+ c = 0.5_${rk}$
99+ allocate (Q, source=A2) ; A = A2
100+ do concurrent(i=1:n, j=1:n)
101+ A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j)
102+ Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
103+ enddo
104+
105+ ! Iteratively compute the Pade approximation.
106+ block
107+ ${rt}$ :: X_tmp(n, n)
108+ p = .true.
109+ do k = 2, order_
110+ c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+ X_tmp = X
112+ #:if rt.startswith('complex')
113+ call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+ #:else
115+ call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+ #:endif
90117 do concurrent(i=1:n, j=1:n)
91- Q (i, j) = Q (i, j) - c*X(i, j) ! Q = Q - c*X
118+ A (i, j) = A (i, j) + c*X(i, j) ! E = E + c*X
92119 enddo
93- endif
94- p = .not. p
95- enddo
96- end block
97-
98- block
99- integer(ilp) :: ipiv(n), info
100- call gesv(n, n, Q, n, ipiv, E, n, info) ! E = inv(Q) @ E
101- call handle_gesv_info(this, info, n, n, n, err0)
102- call linalg_error_handling(err0, err)
103- end block
104-
105- ! Matrix squaring.
106- block
107- ${rt}$ :: E_tmp(n, n)
108- do k = 1, s
109- E_tmp = E
110- #:if rt.startswith('complex')
111- call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, E, n)
112- #:else
113- call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, E, n)
114- #:endif
115- enddo
116- end block
120+ if (p) then
121+ do concurrent(i=1:n, j=1:n)
122+ Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
123+ enddo
124+ else
125+ do concurrent(i=1:n, j=1:n)
126+ Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+ enddo
128+ endif
129+ p = .not. p
130+ enddo
131+ end block
132+
133+ block
134+ integer(ilp) :: ipiv(n), info
135+ call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E
136+ call handle_gesv_info(this, info, n, n, n, err0)
137+ end block
138+
139+ ! Matrix squaring.
140+ block
141+ ${rt}$ :: E_tmp(n, n)
142+ do k = 1, s
143+ E_tmp = A
144+ #:if rt.startswith('complex')
145+ call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
146+ #:else
147+ call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
148+ #:endif
149+ enddo
150+ end block
151+ endif
152+
153+ call linalg_error_handling(err0, err)
154+
117155 return
118- end function stdlib_expm_ ${ri}$
156+ end subroutine stdlib_linalg_ ${ri}$_expm_inplace
119157 #:endfor
120158
121159end submodule stdlib_linalg_matrix_functions
0 commit comments