Skip to content

Commit ead54b1

Browse files
committed
Implementation of the matrix exponential (function and subroutine)
1 parent eb01acd commit ead54b1

File tree

3 files changed

+262
-0
lines changed

3 files changed

+262
-0
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ set(fppFiles
4646
stdlib_linalg_svd.fypp
4747
stdlib_linalg_cholesky.fypp
4848
stdlib_linalg_schur.fypp
49+
stdlib_linalg_matrix_functions.fypp
4950
stdlib_optval.fypp
5051
stdlib_selection.fypp
5152
stdlib_sorting.fypp

src/stdlib_linalg.fypp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ module stdlib_linalg
2828
public :: eigh
2929
public :: eigvals
3030
public :: eigvalsh
31+
public :: expm, matrix_exp
3132
public :: eye
3233
public :: inv
3334
public :: invert
@@ -1678,6 +1679,107 @@ module stdlib_linalg
16781679
#:endfor
16791680
end interface mnorm
16801681

1682+
!> Matrix exponential: function interface
1683+
interface expm
1684+
!! version : experimental
1685+
!!
1686+
!! Computes the exponential of a matrix using a rational Pade approximation.
1687+
!! ([Specification](../page/specs/stdlib_linalg.html#expm))
1688+
!!
1689+
!! ### Description
1690+
!!
1691+
!! This interface provides methods for computing the exponential of a matrix
1692+
!! represented as a standard Fortran rank-2 array. Supported data types include
1693+
!! `real` and `complex`.
1694+
!!
1695+
!! By default, the order of the Pade approximation is set to 10. It can be changed
1696+
!! via the `order` argument which must be non-negative.
1697+
!!
1698+
!! If the input matrix is non-square or the order of the Pade approximation is
1699+
!! negative, the function returns an error state.
1700+
!!
1701+
!! ### Example
1702+
!!
1703+
!! ```fortran
1704+
!! real(dp) :: A(3, 3), E(3, 3)
1705+
!!
1706+
!! A = reshape([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3])
1707+
!!
1708+
!! ! Default Pade approximation of the matrix exponential.
1709+
!! E = expm(A)
1710+
!!
1711+
!! ! Pade approximation with specified order.
1712+
!! E = expm(A, order=12)
1713+
!! ```
1714+
!!
1715+
#:for rk,rt,ri in RC_KINDS_TYPES
1716+
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
1717+
!> Input matrix a(n, n).
1718+
${rt}$, intent(in) :: A(:, :)
1719+
!> [optional] Order of the Pade approximation (default `order=10`)
1720+
integer(ilp), optional, intent(in) :: order
1721+
!> Exponential of the input matrix E = exp(A).
1722+
${rt}$, allocatable :: E(:, :)
1723+
end function stdlib_linalg_${ri}$_expm_fun
1724+
#:endfor
1725+
end interface expm
1726+
1727+
!> Matrix exponential: subroutine interface
1728+
interface matrix_exp
1729+
!! version : experimental
1730+
!!
1731+
!! Computes the exponential of a matrix using a rational Pade approximation.
1732+
!! ([Specification](../page/specs/stdlib_linalg.html#matrix_exp))
1733+
!!
1734+
!! ### Description
1735+
!!
1736+
!! This interface provides methods for computing the exponential of a matrix
1737+
!! represented as a standard Fortran rank-2 array. Supported data types include
1738+
!! `real` and `complex`.
1739+
!!
1740+
!! By default, the order of the Pade approximation is set to 10. It can be changed
1741+
!! via the `order` argument which must be non-negative.
1742+
!!
1743+
!! If the input matrix is non-square or the order of the Pade approximation is
1744+
!! negative, the function returns an error state.
1745+
!!
1746+
!! ### Example
1747+
!!
1748+
!! ```fortran
1749+
!! real(dp) :: A(3, 3), E(3, 3)
1750+
!!
1751+
!! A = reshape([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3])
1752+
!!
1753+
!! ! Default Pade approximation of the matrix exponential.
1754+
!! call matrix_exp(A, E) ! Out-of-place
1755+
!! ! call matrix_exp(A) for in-place computation.
1756+
!!
1757+
!! ! Pade approximation with specified order.
1758+
!! call matrix_exp(A, E, order=12)
1759+
!! ```
1760+
!!
1761+
#:for rk,rt,ri in RC_KINDS_TYPES
1762+
module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
1763+
!> Input matrix A(n, n) / Output matrix E = exp(A)
1764+
${rt}$, intent(inout) :: A(:, :)
1765+
!> [optional] Order of the Pade approximation (default `order=10`)
1766+
integer(ilp), optional, intent(in) :: order
1767+
!> [optional] Error handling.
1768+
type(linalg_state_type), optional, intent(out) :: err
1769+
end subroutine stdlib_linalg_${ri}$_expm_inplace
1770+
1771+
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
1772+
!> Input matrix A(n, n)
1773+
${rt}$, intent(in) :: A(:, :)
1774+
!> Output matrix exponential E = exp(A)
1775+
${rt}$, intent(out) :: E(:, :)
1776+
!> [optional] Order of the Pade approximation (default `order=10`)
1777+
integer(ilp), optional, intent(in) :: order
1778+
!> [optional] Error handling.
1779+
type(linalg_state_type), optional, intent(out) :: err
1780+
end subroutine stdlib_linalg_${ri}$_expm
1781+
#:endfor
1782+
end interface matrix_exp
16811783
contains
16821784

16831785

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#:include "common.fypp"
2+
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
3+
submodule (stdlib_linalg) stdlib_linalg_matrix_functions
4+
use stdlib_constants
5+
use stdlib_linalg_constants
6+
use stdlib_linalg_blas, only: gemm
7+
use stdlib_linalg_lapack, only: gesv
8+
use stdlib_linalg_lapack_aux, only: handle_gesv_info
9+
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
10+
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
11+
implicit none
12+
13+
character(len=*), parameter :: this = "matrix_exponential"
14+
15+
contains
16+
17+
#:for rk,rt,ri in RC_KINDS_TYPES
18+
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
19+
!> Input matrix A(n, n).
20+
${rt}$, intent(in) :: A(:, :)
21+
!> [optional] Order of the Pade approximation.
22+
integer(ilp), optional, intent(in) :: order
23+
!> Exponential of the input matrix E = exp(A).
24+
${rt}$, allocatable :: E(:, :)
25+
26+
E = A ; call stdlib_linalg_${ri}$_expm_inplace(E, order)
27+
end function
28+
29+
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
30+
!> Input matrix A(n, n).
31+
${rt}$, intent(in) :: A(:, :)
32+
!> [optional] Order of the Pade approximation.
33+
integer(ilp), optional, intent(in) :: order
34+
!> [optional] State return flag.
35+
type(linalg_state_type), optional, intent(out) :: err
36+
!> Exponential of the input matrix E = exp(A).
37+
${rt}$, intent(out) :: E(:, :)
38+
39+
type(linalg_state_type) :: err0
40+
integer(ilp) :: lda, n, lde, ne
41+
42+
! Check E sizes
43+
lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp)
44+
lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp)
45+
46+
if (lda<1 .or. n<1 .or. lda/=n .or. lde/=n .or. ne/=n) then
47+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR, &
48+
'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', &
49+
' E must be square (lde=', lde, ', ne=', ne, ')')
50+
else
51+
E(:n, :n) = A(:n, :n) ; call stdlib_linalg_${ri}$_expm_inplace(E, order, err0)
52+
endif
53+
54+
! Process output and return
55+
call linalg_error_handling(err0,err)
56+
57+
return
58+
end subroutine stdlib_linalg_${ri}$_expm
59+
60+
module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
61+
!> Input matrix A(n, n) / Output matrix exponential.
62+
${rt}$, intent(inout) :: A(:, :)
63+
!> [optional] Order of the Pade approximation.
64+
integer(ilp), optional, intent(in) :: order
65+
!> [optional] State return flag.
66+
type(linalg_state_type), optional, intent(out) :: err
67+
68+
! Internal variables.
69+
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
70+
real(${rk}$) :: a_norm, c
71+
integer(ilp) :: m, n, ee, k, s, order_, i, j
72+
logical(lk) :: p
73+
type(linalg_state_type) :: err0
74+
75+
! Deal with optional args.
76+
order_ = 10 ; if (present(order)) order_ = order
77+
78+
! Problem's dimension.
79+
m = size(A, dim=1, kind=ilp) ; n = size(A, dim=2, kind=ilp)
80+
81+
if (m /= n) then
82+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
83+
else if (order_ < 0) then
84+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
85+
needs to be positive, order=', order_)
86+
else
87+
! Compute the L-infinity norm.
88+
a_norm = mnorm(A, "inf")
89+
90+
! Determine scaling factor for the matrix.
91+
ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
92+
s = max(0, ee+1)
93+
94+
! Scale the input matrix & initialize polynomial.
95+
A2 = A/2.0_${rk}$**s ; X = A2
96+
97+
! First step of the Pade approximation.
98+
c = 0.5_${rk}$
99+
allocate (Q, source=A2) ; A = A2
100+
do concurrent(i=1:n, j=1:n)
101+
A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j)
102+
Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
103+
enddo
104+
105+
! Iteratively compute the Pade approximation.
106+
block
107+
${rt}$, allocatable :: X_tmp(:, :)
108+
p = .true.
109+
do k = 2, order_
110+
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+
X_tmp = X
112+
#:if rt.startswith('complex')
113+
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+
#:else
115+
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+
#:endif
117+
do concurrent(i=1:n, j=1:n)
118+
A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
119+
enddo
120+
if (p) then
121+
do concurrent(i=1:n, j=1:n)
122+
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
123+
enddo
124+
else
125+
do concurrent(i=1:n, j=1:n)
126+
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+
enddo
128+
endif
129+
p = .not. p
130+
enddo
131+
end block
132+
133+
block
134+
integer(ilp) :: ipiv(n), info
135+
call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E
136+
call handle_gesv_info(this, info, n, n, n, err0)
137+
end block
138+
139+
! Matrix squaring.
140+
block
141+
${rt}$, allocatable :: E_tmp(:, :)
142+
do k = 1, s
143+
E_tmp = A
144+
#:if rt.startswith('complex')
145+
call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
146+
#:else
147+
call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
148+
#:endif
149+
enddo
150+
end block
151+
endif
152+
153+
call linalg_error_handling(err0, err)
154+
155+
return
156+
end subroutine stdlib_linalg_${ri}$_expm_inplace
157+
#:endfor
158+
159+
end submodule stdlib_linalg_matrix_functions

0 commit comments

Comments
 (0)