@@ -55,27 +55,73 @@ submodule (stdlib_linalg) stdlib_linalg_solve
55
55
type(linalg_state_type), intent(out) :: err
56
56
!> Result array/matrix x[n] or x[n,nrhs]
57
57
${rt}$, allocatable, target :: x${nd}$
58
+
59
+ ! Initialize solution shape from the rhs array
60
+ allocate(x,mold=b)
61
+
62
+ call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,overwrite_a=overwrite_a,err=err)
63
+
64
+ end function stdlib_linalg_${ri}$_solve_${ndsuf}$
65
+
66
+ !> Compute the solution to a real system of linear equations A * X = B (pure interface)
67
+ pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
68
+ !> Input matrix a[n,n]
69
+ ${rt}$, intent(in) :: a(:,:)
70
+ !> Right hand side vector or array, b[n] or b[n,nrhs]
71
+ ${rt}$, intent(in) :: b${nd}$
72
+ !> Result array/matrix x[n] or x[n,nrhs]
73
+ ${rt}$, allocatable, target :: x${nd}$
74
+
75
+ ! Local variables
76
+ ${rt}$, allocatable :: amat(:,:)
77
+
78
+ ! Copy `a` so it can be intent(in)
79
+ allocate(amat,source=a)
80
+
81
+ ! Initialize solution shape from the rhs array
82
+ allocate(x,mold=b)
83
+
84
+ call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(amat,b,x,overwrite_a=.true.)
85
+
86
+ end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
87
+
88
+ !> Compute the solution to a real system of linear equations A * X = B (pure interface)
89
+ pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
90
+ !> Input matrix a[n,n]
91
+ ${rt}$, intent(inout), target :: a(:,:)
92
+ !> Right hand side vector or array, b[n] or b[n,nrhs]
93
+ ${rt}$, intent(in) :: b${nd}$
94
+ !> Result array/matrix x[n] or x[n,nrhs]
95
+ ${rt}$, intent(inout), contiguous, target :: x${nd}$
96
+ !> [optional] Storage array for the diagonal pivot indices
97
+ integer(ilp), optional, intent(inout), target :: pivot(:)
98
+ !> [optional] Can A data be overwritten and destroyed?
99
+ logical(lk), optional, intent(in) :: overwrite_a
100
+ !> [optional] state return flag. On error if not requested, the code will stop
101
+ type(linalg_state_type), optional, intent(out) :: err
58
102
59
103
! Local variables
60
104
type(linalg_state_type) :: err0
61
- integer(ilp) :: lda,n,ldb,nrhs,info
62
- integer(ilp), allocatable :: ipiv(:)
105
+ integer(ilp) :: lda,n,ldb,ldx,nrhsx, nrhs,info,npiv
106
+ integer(ilp), pointer :: ipiv(:)
63
107
logical(lk) :: copy_a
64
108
${rt}$, pointer :: xmat(:,:),amat(:,:)
65
109
66
110
! Problem sizes
67
- lda = size(a,1,kind=ilp)
68
- n = size(a,2,kind=ilp)
69
- ldb = size(b,1,kind=ilp)
70
- nrhs = size(b ,kind=ilp)/ldb
71
-
72
- if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
73
- err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
74
- ', b=',[ldb,nrhs])
75
- allocate(x${nde}$)
76
- call linalg_error_handling(err0,err)
77
- return
78
- end if
111
+ lda = size(a,1,kind=ilp)
112
+ n = size(a,2,kind=ilp)
113
+ ldb = size(b,1,kind=ilp)
114
+ nrhs = size(b ,kind=ilp)/ldb
115
+ ldx = size(x,1,kind=ilp)
116
+ nrhsx = size(x ,kind=ilp)/ldx
117
+
118
+ ! Has a pre-allocated pivots storage array been provided?
119
+ if (present(pivot)) then
120
+ ipiv => pivot
121
+ else
122
+ allocate(ipiv(n))
123
+ endif
124
+ npiv = size(ipiv,kind=ilp)
79
125
80
126
! Can A be overwritten? By default, do not overwrite
81
127
if (present(overwrite_a)) then
@@ -84,8 +130,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
84
130
copy_a = .true._lk
85
131
endif
86
132
87
- ! Pivot indices
88
- allocate(ipiv(n))
133
+ if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs .or. npiv/=n) then
134
+ err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
135
+ 'b=',[ldb,nrhs],' x=',[ldx,nrhsx], &
136
+ 'pivot=',n)
137
+ call linalg_error_handling(err0,err)
138
+ return
139
+ end if
89
140
90
141
! Initialize a matrix temporary
91
142
if (copy_a) then
@@ -95,7 +146,7 @@ submodule (stdlib_linalg) stdlib_linalg_solve
95
146
endif
96
147
97
148
! Initialize solution with the rhs
98
- allocate(x,source=b)
149
+ x = b
99
150
xmat(1:n,1:nrhs) => x
100
151
101
152
! Solve system
@@ -105,64 +156,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
105
156
call handle_gesv_info(info,lda,n,nrhs,err0)
106
157
107
158
if (copy_a) deallocate(amat)
159
+ if (.not.present(pivot)) deallocate(ipiv)
108
160
109
161
! Process output and return
110
162
call linalg_error_handling(err0,err)
111
163
112
- end function stdlib_linalg_${ri}$_solve_${ndsuf}$
113
-
114
- !> Compute the solution to a real system of linear equations A * X = B (pure interface)
115
- pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
116
- !> Input matrix a[n,n]
117
- ${rt}$, intent(in), target :: a(:,:)
118
- !> Right hand side vector or array, b[n] or b[n,nrhs]
119
- ${rt}$, intent(in) :: b${nd}$
120
- !> Result array/matrix x[n] or x[n,nrhs]
121
- ${rt}$, allocatable, target :: x${nd}$
122
-
123
- ! Local variables
124
- type(linalg_state_type) :: err0
125
- integer(ilp) :: lda,n,ldb,nrhs,info
126
- integer(ilp), allocatable :: ipiv(:)
127
- ${rt}$, pointer :: xmat(:,:)
128
- ${rt}$, allocatable :: amat(:,:)
129
-
130
- ! Problem sizes
131
- lda = size(a,1,kind=ilp)
132
- n = size(a,2,kind=ilp)
133
- ldb = size(b,1,kind=ilp)
134
- nrhs = size(b ,kind=ilp)/ldb
135
-
136
- if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
137
- err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
138
- ', b=',[ldb,nrhs])
139
- allocate(x${nde}$)
140
- call linalg_error_handling(err0)
141
- return
142
- end if
143
-
144
- ! Pivot indices
145
- allocate(ipiv(n))
146
-
147
- ! Initialize a matrix temporary
148
- allocate(amat,source=a)
149
-
150
- ! Initialize solution with the rhs
151
- allocate(x,source=b)
152
- xmat(1:n,1:nrhs) => x
153
-
154
- ! Solve system
155
- call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)
156
-
157
- ! Process output
158
- call handle_gesv_info(info,lda,n,nrhs,err0)
159
-
160
- deallocate(amat)
161
-
162
- ! Process output and return
163
- call linalg_error_handling(err0)
164
-
165
- end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
164
+ end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$
165
+
166
166
#:endif
167
167
#:endfor
168
168
#:endfor
0 commit comments