Skip to content

Commit fa36f33

Browse files
committed
Improved implementation + error handling.
1 parent b0a74b1 commit fa36f33

File tree

1 file changed

+88
-26
lines changed

1 file changed

+88
-26
lines changed
Lines changed: 88 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,130 @@
11
#:include "common.fypp"
22
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
33
submodule (stdlib_linalg) stdlib_linalg_matrix_functions
4-
use stdlib_linalg_constants
5-
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
4+
use stdlib_linalg_constants
5+
use stdlib_linalg_lapack, only: gesv
6+
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
67
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
7-
implicit none
8+
implicit none
89

9-
contains
10+
#:for rk, rt, ri in (REAL_KINDS_TYPES)
11+
${rt}$, parameter :: zero_${ri}$ = 0._${rk}$
12+
${rt}$, parameter :: one_${ri}$ = 1._${rk}$
13+
#:endfor
14+
#:for rk, rt, ri in (CMPLX_KINDS_TYPES)
15+
${rt}$, parameter :: zero_${ri}$ = (0._${rk}$, 0._${rk}$)
16+
${rt}$, parameter :: one_${ri}$ = (1._${rk}$, 0._${rk}$)
17+
#:endfor
18+
19+
contains
1020

1121
#:for rk,rt,ri in RC_KINDS_TYPES
12-
module function expm_${ri}$(A, order) result(E)
22+
module function stdlib_expm_${ri}$(A, order, err) result(E)
23+
!> Input matrix A(n, n).
1324
${rt}$, intent(in) :: A(:, :)
25+
!> [optional] Order of the Pade approximation.
1426
integer(ilp), optional, intent(in) :: order
27+
!> [optional] State return flag.
28+
type(linalg_state_type), optional, intent(out) :: err
29+
!> Exponential of the input matrix E = exp(A).
1530
${rt}$, allocatable :: E(:, :)
1631

32+
! Internal variables.
1733
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
18-
real(${rk}$) :: a_norm, c
19-
integer(ilp) :: n, ee, k, s
20-
logical(lk) :: p
21-
integer(ilp) :: p_order
34+
real(${rk}$) :: a_norm, c
35+
integer(ilp) :: m, n, ee, k, s, order_, i, j
36+
logical(lk) :: p
37+
character(len=*), parameter :: this = "expm"
38+
type(linalg_state_type) :: err0
2239

2340
! Deal with optional args.
24-
p_order = 10 ; if (present(order)) p_order = order
41+
order_ = 10 ; if (present(order)) order_ = order
42+
43+
! Problem's dimension.
44+
m = size(A, 1) ; n = size(A, 2)
2545

26-
n = size(A, 1)
46+
if (m /= n) then
47+
err = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
48+
call linalg_error_handling(err0, err)
49+
else if (order_ < 0) then
50+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
51+
needs to be positive, order=', order_)
52+
call linalg_error_handling(err0, err)
53+
endif
2754

2855
! Compute the L-infinity norm.
2956
a_norm = mnorm(A, "inf")
3057

3158
! Determine scaling factor for the matrix.
3259
ee = int(log(a_norm) / log(2.0_${rk}$)) + 1
33-
s = max(0, ee+1)
60+
s = max(0, ee+1)
3461

3562
! Scale the input matrix & initialize polynomial.
36-
A2 = A / 2.0_${rk}$**s
37-
X = A2
63+
A2 = A/2.0_${rk}$**s ; X = A2
3864

39-
! Initialize P & Q and add first step.
65+
! First step of the Pade approximation.
4066
c = 0.5_${rk}$
41-
E = eye(n, mold=1.0_${rk}$) ; E = E + c*A2
42-
43-
Q = eye(n, mold=1.0_${rk}$) ; Q = Q - c*A2
67+
allocate (E, source=A2) ; allocate (Q, source=A2)
68+
do concurrent(i=1:n, j=1:n)
69+
E(i, j) = c*E(i, j) ; if (i == j) E(i, j) = 1.0_${rk}$ + E(i, j) ! E = I + c*A2
70+
Q(i, j) = -c*Q(i, j) ; if (i == j) Q(i, j) = 1.0_${rk}$ + Q(i, j) ! Q = I - c*A2
71+
enddo
4472

4573
! Iteratively compute the Pade approximation.
4674
p = .true.
47-
do k = 2, p_order
48-
c = c*(p_order - k + 1) / (k * (2*p_order - k + 1))
75+
do k = 2, order_
76+
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
4977
X = matmul(A2, X)
50-
E = E + c*X
78+
do concurrent(i=1:n, j=1:n)
79+
E(i, j) = E(i, j) + c*X(i, j) ! E = E + c*X
80+
enddo
5181
if (p) then
52-
Q = Q + c*X
82+
do concurrent(i=1:n, j=1:n)
83+
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
84+
enddo
5385
else
54-
Q = Q - c*X
86+
do concurrent(i=1:n, j=1:n)
87+
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
88+
enddo
5589
endif
5690
p = .not. p
5791
enddo
5892

59-
E = matmul(inv(Q), E)
93+
block
94+
integer(ilp) :: ipiv(n), info
95+
call gesv(n, n, Q, n, ipiv, E, n, info) ! E = inv(Q) @ E
96+
call handle_gesv_info(info, n, n, n, err0)
97+
call linalg_error_handling(err0, err)
98+
end block
99+
100+
! This loop should eventually be replaced by a fast matrix_power function.
60101
do k = 1, s
61102
E = matmul(E, E)
62103
enddo
63-
64104
return
65-
end function
105+
contains
106+
elemental subroutine handle_gesv_info(info,lda,n,nrhs,err)
107+
integer(ilp), intent(in) :: info,lda,n,nrhs
108+
type(linalg_state_type), intent(out) :: err
109+
! Process output
110+
select case (info)
111+
case (0)
112+
! Success
113+
case (-1)
114+
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size n=',n)
115+
case (-2)
116+
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid rhs size n=',nrhs)
117+
case (-4)
118+
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid matrix size a=',[lda,n])
119+
case (-7)
120+
err = linalg_state_type(this,LINALG_ERROR,'invalid matrix size a=',[lda,n])
121+
case (1:)
122+
err = linalg_state_type(this,LINALG_ERROR,'singular matrix')
123+
case default
124+
err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
125+
end select
126+
end subroutine handle_gesv_info
127+
end function stdlib_expm_${ri}$
66128
#:endfor
67129

68130
end submodule stdlib_linalg_matrix_functions

0 commit comments

Comments
 (0)