@@ -15,22 +15,62 @@ submodule (stdlib_linalg) stdlib_linalg_matrix_functions
15
15
contains
16
16
17
17
#:for rk,rt,ri in RC_KINDS_TYPES
18
- module function stdlib_expm_${ri}$(A, order, err) result(E)
18
!> Functional interface to the matrix exponential: returns E = exp(A) in a
!> freshly allocated array and leaves the input matrix untouched.
module function stdlib_linalg_${ri}$_expm_fun(A, order) result(E)
    !> Input matrix A(n, n).
    ${rt}$, intent(in) :: A(:, :)
    !> [optional] Order of the Pade approximation.
    integer(ilp), optional, intent(in) :: order
    !> Exponential of the input matrix E = exp(A).
    ${rt}$, allocatable :: E(:, :)

    ! Start from a private copy of A; the in-place routine overwrites its
    ! argument with the result.
    allocate (E, source=A)

    ! `order` is forwarded as-is (it may be absent). No state flag is passed,
    ! so error reporting is left entirely to the in-place routine.
    call stdlib_linalg_${ri}$_expm_inplace(E, order)
end function stdlib_linalg_${ri}$_expm_fun
28
+
29
!> Subroutine interface to the matrix exponential: copies the leading n-by-n
!> block of A into E and exponentiates that block in place.
!>
!> Only E(:n, :n) is written, where n = size(A, 2); any further rows/columns
!> of a larger E are left untouched (hence undefined, as E is intent(out)).
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
    !> Input matrix A(n, n).
    ${rt}$, intent(in) :: A(:, :)
    !> [optional] Order of the Pade approximation.
    integer(ilp), optional, intent(in) :: order
    !> [optional] State return flag.
    type(linalg_state_type), optional, intent(out) :: err
    !> Exponential of the input matrix E = exp(A).
    ${rt}$, intent(out) :: E(:, :)

    ! Local state: every outcome (size error or solver status) is gathered here
    ! and handed to the caller exactly once at the end.
    type(linalg_state_type) :: err0
    integer(ilp) :: lda, n, lde, ne

    ! Extents of the input and output arrays.
    lda = size(A, 1, kind=ilp)
    n   = size(A, 2, kind=ilp)
    lde = size(E, 1, kind=ilp)
    ne  = size(E, 2, kind=ilp)

    if (lda<1 .or. n<1 .or. lda<n .or. lde<n .or. ne<n) then
        ! A is empty / has fewer rows than columns, or E cannot hold an
        ! n-by-n result.
        err0 = linalg_state_type(this,LINALG_VALUE_ERROR, &
                                 'invalid matrix sizes: A=',[lda,n], &
                                 ' E=',[lde,ne])
    else
        E(:n, :n) = A(:n, :n)

        ! Operate on the initialised n-by-n block only: passing the whole of a
        ! larger E would feed undefined entries (or a non-square array) to the
        ! in-place routine.
        ! The status is collected in err0 rather than in `err`, so that the
        ! final handler call below reports it instead of replacing it with
        ! the untouched local state.
        call stdlib_linalg_${ri}$_expm_inplace(E(:n, :n), order, err0)
    endif

    ! Process output and return.
    call linalg_error_handling(err0, err)
end subroutine stdlib_linalg_${ri}$_expm
59
+
60
+ module subroutine stdlib_linalg_${ri}$_expm_inplace(A, order, err)
61
+ !> Input matrix A(n, n) / Output matrix exponential.
62
+ ${rt}$, intent(inout) :: A(:, :)
63
+ !> [optional] Order of the Pade approximation.
64
+ integer(ilp), optional, intent(in) :: order
65
+ !> [optional] State return flag.
66
+ type(linalg_state_type), optional, intent(out) :: err
27
67
28
68
! Internal variables.
29
- ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
30
- real(${rk}$) :: a_norm, c
31
- integer(ilp) :: m, n, ee, k, s, order_, i, j
32
- logical(lk) :: p
33
- type(linalg_state_type) :: err0
69
+ ${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
70
+ real(${rk}$) :: a_norm, c
71
+ integer(ilp) :: m, n, ee, k, s, order_, i, j
72
+ logical(lk) :: p
73
+ type(linalg_state_type) :: err0
34
74
35
75
! Deal with optional args.
36
76
order_ = 10 ; if (present(order)) order_ = order
@@ -40,82 +80,80 @@ contains
40
80
41
81
if (m /= n) then
42
82
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n])
43
- call linalg_error_handling(err0, err)
44
- return
45
83
else if (order_ < 0) then
46
84
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation &
47
85
needs to be positive, order=', order_)
48
- call linalg_error_handling(err0, err)
49
- return
50
- endif
86
+ else
87
+ ! Compute the L-infinity norm.
88
+ a_norm = mnorm(A, "inf")
51
89
52
- ! Compute the L-infinity norm.
53
- a_norm = mnorm(A, "inf")
54
-
55
- ! Determine scaling factor for the matrix.
56
- ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
57
- s = max(0, ee+1)
58
-
59
- ! Scale the input matrix & initialize polynomial.
60
- A2 = A/2.0_${rk}$**s ; X = A2
61
-
62
- ! First step of the Pade approximation.
63
- c = 0.5_${rk}$
64
- allocate (E, source=A2) ; allocate (Q, source=A2)
65
- do concurrent(i=1:n, j=1:n)
66
- E(i, j) = merge(1.0_${rk}$ + c*E(i, j), c*E(i, j), i == j)
67
- Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
68
- enddo
69
-
70
- ! Iteratively compute the Pade approximation.
71
- block
72
- ${rt}$ :: X_tmp(n, n)
73
- p = .true.
74
- do k = 2, order_
75
- c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
76
- X_tmp = X
77
- #:if rt.startswith('complex')
78
- call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
79
- #:else
80
- call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
81
- #:endif
82
- do concurrent(i=1:n, j=1:n)
83
- E(i, j) = E(i, j) + c*X(i, j) ! E = E + c*X
84
- enddo
85
- if (p) then
86
- do concurrent(i=1:n, j=1:n)
87
- Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
88
- enddo
89
- else
90
+ ! Determine scaling factor for the matrix.
91
+ ee = int(log(a_norm) / log2_${rk}$, kind=ilp) + 1
92
+ s = max(0, ee+1)
93
+
94
+ ! Scale the input matrix & initialize polynomial.
95
+ A2 = A/2.0_${rk}$**s ; X = A2
96
+
97
+ ! First step of the Pade approximation.
98
+ c = 0.5_${rk}$
99
+ allocate (Q, source=A2) ; A = A2
100
+ do concurrent(i=1:n, j=1:n)
101
+ A(i, j) = merge(1.0_${rk}$ + c*A(i, j), c*A(i, j), i == j)
102
+ Q(i, j) = merge(1.0_${rk}$ - c*Q(i, j), -c*Q(i, j), i == j)
103
+ enddo
104
+
105
+ ! Iteratively compute the Pade approximation.
106
+ block
107
+ ${rt}$ :: X_tmp(n, n)
108
+ p = .true.
109
+ do k = 2, order_
110
+ c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111
+ X_tmp = X
112
+ #:if rt.startswith('complex')
113
+ call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114
+ #:else
115
+ call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116
+ #:endif
90
117
do concurrent(i=1:n, j=1:n)
91
- Q (i, j) = Q (i, j) - c*X(i, j) ! Q = Q - c*X
118
+ A (i, j) = A (i, j) + c*X(i, j) ! E = E + c*X
92
119
enddo
93
- endif
94
- p = .not. p
95
- enddo
96
- end block
97
-
98
- block
99
- integer(ilp) :: ipiv(n), info
100
- call gesv(n, n, Q, n, ipiv, E, n, info) ! E = inv(Q) @ E
101
- call handle_gesv_info(this, info, n, n, n, err0)
102
- call linalg_error_handling(err0, err)
103
- end block
104
-
105
- ! Matrix squaring.
106
- block
107
- ${rt}$ :: E_tmp(n, n)
108
- do k = 1, s
109
- E_tmp = E
110
- #:if rt.startswith('complex')
111
- call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, E, n)
112
- #:else
113
- call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, E, n)
114
- #:endif
115
- enddo
116
- end block
120
+ if (p) then
121
+ do concurrent(i=1:n, j=1:n)
122
+ Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
123
+ enddo
124
+ else
125
+ do concurrent(i=1:n, j=1:n)
126
+ Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127
+ enddo
128
+ endif
129
+ p = .not. p
130
+ enddo
131
+ end block
132
+
133
+ block
134
+ integer(ilp) :: ipiv(n), info
135
+ call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E
136
+ call handle_gesv_info(this, info, n, n, n, err0)
137
+ end block
138
+
139
+ ! Matrix squaring.
140
+ block
141
+ ${rt}$ :: E_tmp(n, n)
142
+ do k = 1, s
143
+ E_tmp = A
144
+ #:if rt.startswith('complex')
145
+ call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
146
+ #:else
147
+ call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
148
+ #:endif
149
+ enddo
150
+ end block
151
+ endif
152
+
153
+ call linalg_error_handling(err0, err)
154
+
117
155
return
118
- end function stdlib_expm_ ${ri}$
156
+ end subroutine stdlib_linalg_ ${ri}$_expm_inplace
119
157
#:endfor
120
158
121
159
end submodule stdlib_linalg_matrix_functions
0 commit comments