Skip to content

Commit a381d0b

Browse files
committed
subroutine driver and interface (in-place and out-of-place)
1 parent d97043d commit a381d0b

File tree

3 files changed

+177
-84
lines changed

3 files changed

+177
-84
lines changed

src/stdlib_linalg.fypp

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module stdlib_linalg
2828
public :: eigh
2929
public :: eigvals
3030
public :: eigvalsh
31-
public :: expm
31+
public :: expm, matrix_exp
3232
public :: eye
3333
public :: inv
3434
public :: invert
@@ -1713,19 +1713,74 @@ module stdlib_linalg
17131713
!! ```
17141714
!!
17151715
#:for rk,rt,ri in RC_KINDS_TYPES
1716-
module function stdlib_expm_${ri}$(A, order, err) result(E)
1716+
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
17171717
!> Input matrix a(n, n).
17181718
${rt}$, intent(in) :: A(:, :)
17191719
!> [optional] Order of the Pade approximation (default `order=10`)
17201720
integer(ilp), optional, intent(in) :: order
1721-
!> [optional] State return flag. On error, if not requested, the code will stop.
1722-
type(linalg_state_type), optional, intent(out) :: err
17231721
!> Exponential of the input matrix E = exp(A).
17241722
${rt}$, allocatable :: E(:, :)
1725-
end function stdlib_expm_${ri}$
1723+
end function stdlib_linalg_${ri}$_expm_fun
17261724
#:endfor
17271725
end interface expm
17281726

1727+
!> Matrix exponential: subroutine interface
1728+
interface matrix_exp
1729+
!! version : experimental
1730+
!!
1731+
!! Computes the exponential of a matrix using a rational Pade approximation.
1732+
!! ([Specification](../page/specs/stdlib_linalg.html#matrix_exp))
1733+
!!
1734+
!! ### Description
1735+
!!
1736+
!! This interface provides methods for computing the exponential of a matrix
1737+
!! represented as a standard Fortran rank-2 array. Supported data types include
1738+
!! `real` and `complex`.
1739+
!!
1740+
!! By default, the order of the Pade approximation is set to 10. It can be changed
1741+
!! via the `order` argument which must be non-negative.
1742+
!!
1743+
!! If the input matrix is non-square or the order of the Pade approximation is
1744+
!! negative, the function returns an error state.
1745+
!!
1746+
!! ### Example
1747+
!!
1748+
!! ```fortran
1749+
!! real(dp) :: A(3, 3), E(3, 3)
1750+
!!
1751+
!! A = reshape([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3])
1752+
!!
1753+
!! ! Default Pade approximation of the matrix exponential.
1754+
!! call matrix_exp(A, E) ! Out-of-place
1755+
!! ! call matrix_exp(A) for in-place computation.
1756+
!!
1757+
!! ! Pade approximation with specified order.
1758+
!! call matrix_exp(A, E, order=12)
1759+
!! ```
1760+
!!
1761+
#:for rk,rt,ri in RC_KINDS_TYPES
1762+
module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
1763+
!> Input matrix A(n, n) / Output matrix E = exp(A)
1764+
${rt}$, intent(inout) :: A(:, :)
1765+
!> [optional] Order of the Pade approximation (default `order=10`)
1766+
integer(ilp), optional, intent(in) :: order
1767+
!> [optional] Error handling.
1768+
type(linalg_state_type), optional, intent(out) :: err
1769+
end subroutine stdlib_linalg_${ri}$_expm_inplace
1770+
1771+
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
1772+
!> Input matrix A(n, n)
1773+
${rt}$, intent(in) :: A(:, :)
1774+
!> Output matrix exponential E = exp(A)
1775+
${rt}$, intent(out) :: E(:, :)
1776+
!> [optional] Order of the Pade approximation (default `order=10`)
1777+
integer(ilp), optional, intent(in) :: order
1778+
!> [optional] Error handling.
1779+
type(linalg_state_type), optional, intent(out) :: err
1780+
end subroutine stdlib_linalg_${ri}$_expm
1781+
#:endfor
1782+
end interface matrix_exp
1783+
17291784
contains
17301785

17311786

src/stdlib_linalg_matrix_functions.fypp

Lines changed: 114 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,62 @@ submodule (stdlib_linalg) stdlib_linalg_matrix_functions
1515
contains
1616

1717
#:for rk,rt,ri in RC_KINDS_TYPES
18-
module function stdlib_expm_${ri}$(A, order, err) result(E)
18+
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
19+
!> Input matrix A(n, n).
20+
${rt}$, intent(in) :: A(:, :)
21+
!> [optional] Order of the Pade approximation.
22+
integer(ilp), optional, intent(in) :: order
23+
!> Exponential of the input matrix E = exp(A).
24+
${rt}$, allocatable :: E(:, :)
25+
26+
E = A ; call stdlib_linalg_${ri}$_expm_inplace(E, order)
27+
end function
28+
29+
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
1930
!> Input matrix A(n, n).
2031
${rt}$, intent(in) :: A(:, :)
2132
!> [optional] Order of the Pade approximation.
2233
integer(ilp), optional, intent(in) :: order
2334
!> [optional] State return flag.
2435
type(linalg_state_type), optional, intent(out) :: err
2536
!> Exponential of the input matrix E = exp(A).
26-
${rt}$, allocatable :: E(:, :)
37+
${rt}$, intent(out) :: E(:, :)
38+
39+
type(linalg_state_type) :: err0
40+
integer(ilp) :: lda, n, lde, ne
41+
42+
! Check E sizes
43+
lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp)
44+
lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp)
45+
46+
if (lda<1 .or. n<1 .or. lda<n .or. lde<n .or. ne<n) then
47+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR, &
48+
'invalid matrix sizes: A=',[lda,n], &
49+
' E=',[lde,ne])
50+
else
51+
E(:n, :n) = A(:n, :n) ; call stdlib_linalg_${ri}$_expm_inplace(E, order, err)
52+
endif
53+
54+
! Process output and return
55+
call linalg_error_handling(err0,err)
56+
57+
return
58+
end subroutine stdlib_linalg_${ri}$_expm
59+
60+
module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
61+
!> Input matrix A(n, n) / Output matrix exponential.
62+
${rt}$, intent(inout) :: A(:, :)
63+
!> [optional] Order of the Pade approximation.
64+
integer(ilp), optional, intent(in) :: order
65+
!> [optional] State return flag.
66+
type(linalg_state_type), optional, intent(out) :: err
2767

2868
! Internal variables.
29-
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
30-
real(${rk}$) :: a_norm, c
31-
integer(ilp) :: m, n, ee, k, s, order_, i, j
32-
logical(lk) :: p
33-
type(linalg_state_type) :: err0
69+
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
70+
real(${rk}$) :: a_norm, c
71+
integer(ilp) :: m, n, ee, k, s, order_, i, j
72+
logical(lk) :: p
73+
type(linalg_state_type) :: err0
3474

3575
! Deal with optional args.
3676
order_ = 10 ; if (present(order)) order_ = order
@@ -40,82 +80,80 @@ contains
4080

4181
if (m /= n) then
4282
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
43-
call linalg_error_handling(err0, err)
44-
return
4583
else if (order_ < 0) then
4684
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
4785
needs to be positive, order=', order_)
48-
call linalg_error_handling(err0, err)
49-
return
50-
endif
86+
else
87+
! Compute the L-infinity norm.
88+
a_norm = mnorm(A, "inf")
5189

52-
! Compute the L-infinity norm.
53-
a_norm = mnorm(A, "inf")
54-
55-
! Determine scaling factor for the matrix.
56-
ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
57-
s = max(0, ee+1)
58-
59-
! Scale the input matrix & initialize polynomial.
60-
A2 = A/2.0_${rk}$**s ; X = A2
61-
62-
! First step of the Pade approximation.
63-
c = 0.5_${rk}$
64-
allocate (E, source=A2) ; allocate (Q, source=A2)
65-
do concurrent(i=1:n, j=1:n)
66-
E(i, j) = merge(1.0_${rk}$ + c*E(i, j), c*E(i, j), i == j)
67-
Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
68-
enddo
69-
70-
! Iteratively compute the Pade approximation.
71-
block
72-
${rt}$ :: X_tmp(n, n)
73-
p = .true.
74-
do k = 2, order_
75-
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
76-
X_tmp = X
77-
#:if rt.startswith('complex')
78-
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
79-
#:else
80-
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
81-
#:endif
82-
do concurrent(i=1:n, j=1:n)
83-
E(i, j) = E(i, j) + c*X(i, j) ! E = E + c*X
84-
enddo
85-
if (p) then
86-
do concurrent(i=1:n, j=1:n)
87-
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
88-
enddo
89-
else
90+
! Determine scaling factor for the matrix.
91+
ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
92+
s = max(0, ee+1)
93+
94+
! Scale the input matrix & initialize polynomial.
95+
A2 = A/2.0_${rk}$**s ; X = A2
96+
97+
! First step of the Pade approximation.
98+
c = 0.5_${rk}$
99+
allocate (Q, source=A2) ; A = A2
100+
do concurrent(i=1:n, j=1:n)
101+
A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j)
102+
Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
103+
enddo
104+
105+
! Iteratively compute the Pade approximation.
106+
block
107+
${rt}$ :: X_tmp(n, n)
108+
p = .true.
109+
do k = 2, order_
110+
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+
X_tmp = X
112+
#:if rt.startswith('complex')
113+
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+
#:else
115+
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+
#:endif
90117
do concurrent(i=1:n, j=1:n)
91-
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
118+
A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
92119
enddo
93-
endif
94-
p = .not. p
95-
enddo
96-
end block
97-
98-
block
99-
integer(ilp) :: ipiv(n), info
100-
call gesv(n, n, Q, n, ipiv, E, n, info) ! E = inv(Q) @ E
101-
call handle_gesv_info(this, info, n, n, n, err0)
102-
call linalg_error_handling(err0, err)
103-
end block
104-
105-
! Matrix squaring.
106-
block
107-
${rt}$ :: E_tmp(n, n)
108-
do k = 1, s
109-
E_tmp = E
110-
#:if rt.startswith('complex')
111-
call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, E, n)
112-
#:else
113-
call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, E, n)
114-
#:endif
115-
enddo
116-
end block
120+
if (p) then
121+
do concurrent(i=1:n, j=1:n)
122+
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
123+
enddo
124+
else
125+
do concurrent(i=1:n, j=1:n)
126+
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+
enddo
128+
endif
129+
p = .not. p
130+
enddo
131+
end block
132+
133+
block
134+
integer(ilp) :: ipiv(n), info
135+
call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E
136+
call handle_gesv_info(this, info, n, n, n, err0)
137+
end block
138+
139+
! Matrix squaring.
140+
block
141+
${rt}$ :: E_tmp(n, n)
142+
do k = 1, s
143+
E_tmp = A
144+
#:if rt.startswith('complex')
145+
call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
146+
#:else
147+
call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
148+
#:endif
149+
enddo
150+
end block
151+
endif
152+
153+
call linalg_error_handling(err0, err)
154+
117155
return
118-
end function stdlib_expm_${ri}$
156+
end subroutine stdlib_linalg_${ri}$_expm_inplace
119157
#:endfor
120158

121159
end submodule stdlib_linalg_matrix_functions

test/linalg/test_linalg_expm.fypp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
module test_linalg_expm
55
use testdrive, only: error_type, check, new_unittest, unittest_type
66
use stdlib_linalg_constants
7-
use stdlib_linalg, only: expm, eye, norm
7+
use stdlib_linalg, only: expm, eye, norm, matrix_exp
88
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
99
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
1010

@@ -82,13 +82,13 @@ module test_linalg_expm
8282
enddo
8383

8484
! Compute matrix exponential.
85-
E = expm(A, order=-1, err=err)
85+
call matrix_exp(A, E, order=-1, err=err)
8686
! Check result.
8787
call check(error, err%error(), "Negative Pade order")
8888
if (allocated(error)) return
8989

9090
! Compute matrix exponential.
91-
E = expm(A(:n, :n-1), err=err)
91+
call matrix_exp(A(:n, :n-1), E, err=err)
9292
! Check result.
9393
call check(error, err%error(), "Invalid matrix size")
9494
if (allocated(error)) return

0 commit comments

Comments
 (0)