Skip to content

Commit 316c44a

Browse files
committed
implement subroutine interface
1 parent bc13246 commit 316c44a

File tree

2 files changed

+96
-72
lines changed

2 files changed

+96
-72
lines changed

src/stdlib_linalg.fypp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module stdlib_linalg
2121
public :: diag
2222
public :: eye
2323
public :: solve
24+
public :: solve_lu
2425
public :: trace
2526
public :: outer_product
2627
public :: kronecker_product
@@ -264,7 +265,7 @@ module stdlib_linalg
264265
end function stdlib_linalg_${ri}$_solve_${ndsuf}$
265266
pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
266267
!> Input matrix a[n,n]
267-
${rt}$, intent(in), target :: a(:,:)
268+
${rt}$, intent(in) :: a(:,:)
268269
!> Right hand side vector or array, b[n] or b[n,nrhs]
269270
${rt}$, intent(in) :: b${nd}$
270271
!> Result array/matrix x[n] or x[n,nrhs]
@@ -275,6 +276,29 @@ module stdlib_linalg
275276
#:endfor
276277
end interface solve
277278

279+
interface solve_lu
280+
#:for nd,ndsuf,nde in ALL_RHS
281+
#:for rk,rt,ri in RC_KINDS_TYPES
282+
#:if rk!="xdp"
283+
pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
284+
!> Input matrix a[n,n]
285+
${rt}$, intent(inout), target :: a(:,:)
286+
!> Right hand side vector or array, b[n] or b[n,nrhs]
287+
${rt}$, intent(in) :: b${nd}$
288+
!> Result array/matrix x[n] or x[n,nrhs]
289+
${rt}$, intent(inout), contiguous, target :: x${nd}$
290+
!> [optional] Storage array for the diagonal pivot indices
291+
integer(ilp), optional, intent(inout), target :: pivot(:)
292+
!> [optional] Can A data be overwritten and destroyed?
293+
logical(lk), optional, intent(in) :: overwrite_a
294+
!> [optional] state return flag. On error if not requested, the code will stop
295+
type(linalg_state_type), optional, intent(out) :: err
296+
end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$
297+
#:endif
298+
#:endfor
299+
#:endfor
300+
end interface solve_lu
301+
278302
interface det
279303
!! version: experimental
280304
!!

src/stdlib_linalg_solve.fypp

Lines changed: 71 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,73 @@ submodule (stdlib_linalg) stdlib_linalg_solve
5555
type(linalg_state_type), intent(out) :: err
5656
!> Result array/matrix x[n] or x[n,nrhs]
5757
${rt}$, allocatable, target :: x${nd}$
58+
59+
! Initialize solution shape from the rhs array
60+
allocate(x,mold=b)
61+
62+
call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,overwrite_a=overwrite_a,err=err)
63+
64+
end function stdlib_linalg_${ri}$_solve_${ndsuf}$
65+
66+
!> Compute the solution to a real system of linear equations A * X = B (pure interface)
67+
pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
68+
!> Input matrix a[n,n]
69+
${rt}$, intent(in) :: a(:,:)
70+
!> Right hand side vector or array, b[n] or b[n,nrhs]
71+
${rt}$, intent(in) :: b${nd}$
72+
!> Result array/matrix x[n] or x[n,nrhs]
73+
${rt}$, allocatable, target :: x${nd}$
74+
75+
! Local variables
76+
${rt}$, allocatable :: amat(:,:)
77+
78+
! Copy `a` so it can be intent(in)
79+
allocate(amat,source=a)
80+
81+
! Initialize solution shape from the rhs array
82+
allocate(x,mold=b)
83+
84+
call stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(amat,b,x,overwrite_a=.true.)
85+
86+
end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
87+
88+
!> Compute the solution to a real system of linear equations A * X = B (pure interface)
89+
pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
90+
!> Input matrix a[n,n]
91+
${rt}$, intent(inout), target :: a(:,:)
92+
!> Right hand side vector or array, b[n] or b[n,nrhs]
93+
${rt}$, intent(in) :: b${nd}$
94+
!> Result array/matrix x[n] or x[n,nrhs]
95+
${rt}$, intent(inout), contiguous, target :: x${nd}$
96+
!> [optional] Storage array for the diagonal pivot indices
97+
integer(ilp), optional, intent(inout), target :: pivot(:)
98+
!> [optional] Can A data be overwritten and destroyed?
99+
logical(lk), optional, intent(in) :: overwrite_a
100+
!> [optional] state return flag. On error if not requested, the code will stop
101+
type(linalg_state_type), optional, intent(out) :: err
58102

59103
! Local variables
60104
type(linalg_state_type) :: err0
61-
integer(ilp) :: lda,n,ldb,nrhs,info
62-
integer(ilp), allocatable :: ipiv(:)
105+
integer(ilp) :: lda,n,ldb,ldx,nrhsx,nrhs,info,npiv
106+
integer(ilp), pointer :: ipiv(:)
63107
logical(lk) :: copy_a
64108
${rt}$, pointer :: xmat(:,:),amat(:,:)
65109

66110
! Problem sizes
67-
lda = size(a,1,kind=ilp)
68-
n = size(a,2,kind=ilp)
69-
ldb = size(b,1,kind=ilp)
70-
nrhs = size(b ,kind=ilp)/ldb
71-
72-
if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
73-
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
74-
', b=',[ldb,nrhs])
75-
allocate(x${nde}$)
76-
call linalg_error_handling(err0,err)
77-
return
78-
end if
111+
lda = size(a,1,kind=ilp)
112+
n = size(a,2,kind=ilp)
113+
ldb = size(b,1,kind=ilp)
114+
nrhs = size(b ,kind=ilp)/ldb
115+
ldx = size(x,1,kind=ilp)
116+
nrhsx = size(x ,kind=ilp)/ldx
117+
118+
! Has a pre-allocated pivots storage array been provided?
119+
if (present(pivot)) then
120+
ipiv => pivot
121+
else
122+
allocate(ipiv(n))
123+
endif
124+
npiv = size(ipiv,kind=ilp)
79125

80126
! Can A be overwritten? By default, do not overwrite
81127
if (present(overwrite_a)) then
@@ -84,8 +130,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
84130
copy_a = .true._lk
85131
endif
86132

87-
! Pivot indices
88-
allocate(ipiv(n))
133+
if (any([lda,n,ldb]<1) .or. any([lda,ldb,ldx]/=n) .or. nrhsx/=nrhs .or. npiv/=n) then
134+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
135+
'b=',[ldb,nrhs],' x=',[ldx,nrhsx], &
136+
'pivot=',n)
137+
call linalg_error_handling(err0,err)
138+
return
139+
end if
89140

90141
! Initialize a matrix temporary
91142
if (copy_a) then
@@ -95,7 +146,7 @@ submodule (stdlib_linalg) stdlib_linalg_solve
95146
endif
96147

97148
! Initialize solution with the rhs
98-
allocate(x,source=b)
149+
x = b
99150
xmat(1:n,1:nrhs) => x
100151

101152
! Solve system
@@ -105,64 +156,13 @@ submodule (stdlib_linalg) stdlib_linalg_solve
105156
call handle_gesv_info(info,lda,n,nrhs,err0)
106157

107158
if (copy_a) deallocate(amat)
159+
if (.not.present(pivot)) deallocate(ipiv)
108160

109161
! Process output and return
110162
call linalg_error_handling(err0,err)
111163

112-
end function stdlib_linalg_${ri}$_solve_${ndsuf}$
113-
114-
!> Compute the solution to a real system of linear equations A * X = B (pure interface)
115-
pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
116-
!> Input matrix a[n,n]
117-
${rt}$, intent(in), target :: a(:,:)
118-
!> Right hand side vector or array, b[n] or b[n,nrhs]
119-
${rt}$, intent(in) :: b${nd}$
120-
!> Result array/matrix x[n] or x[n,nrhs]
121-
${rt}$, allocatable, target :: x${nd}$
122-
123-
! Local variables
124-
type(linalg_state_type) :: err0
125-
integer(ilp) :: lda,n,ldb,nrhs,info
126-
integer(ilp), allocatable :: ipiv(:)
127-
${rt}$, pointer :: xmat(:,:)
128-
${rt}$, allocatable :: amat(:,:)
129-
130-
! Problem sizes
131-
lda = size(a,1,kind=ilp)
132-
n = size(a,2,kind=ilp)
133-
ldb = size(b,1,kind=ilp)
134-
nrhs = size(b ,kind=ilp)/ldb
135-
136-
if (any([lda,n,ldb]<1) .or. any([lda,ldb]/=n)) then
137-
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
138-
', b=',[ldb,nrhs])
139-
allocate(x${nde}$)
140-
call linalg_error_handling(err0)
141-
return
142-
end if
143-
144-
! Pivot indices
145-
allocate(ipiv(n))
146-
147-
! Initialize a matrix temporary
148-
allocate(amat,source=a)
149-
150-
! Initialize solution with the rhs
151-
allocate(x,source=b)
152-
xmat(1:n,1:nrhs) => x
153-
154-
! Solve system
155-
call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)
156-
157-
! Process output
158-
call handle_gesv_info(info,lda,n,nrhs,err0)
159-
160-
deallocate(amat)
161-
162-
! Process output and return
163-
call linalg_error_handling(err0)
164-
165-
end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
164+
end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$
165+
166166
#:endif
167167
#:endfor
168168
#:endfor

0 commit comments

Comments
 (0)