Skip to content

Commit a2afe6b

Browse files
committed
base implementation
1 parent 06bfe4b commit a2afe6b

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ set(fppFiles
2424
stdlib_linalg_outer_product.fypp
2525
stdlib_linalg_kronecker.fypp
2626
stdlib_linalg_cross_product.fypp
27+
stdlib_linalg_solve.fypp
2728
stdlib_linalg_state.fypp
2829
stdlib_optval.fypp
2930
stdlib_selection.fypp

src/stdlib_linalg_solve.fypp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#:include "common.fypp"
2+
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
3+
#:set RHS_SUFFIX = ["one","many"]
4+
#:set RHS_SYMBOL = [ranksuffix(r) for r in [1,2]]
5+
#:set RHS_EMPTY = [emptyranksuffix(r) for r in [1,2]]
6+
#:set ALL_RHS = list(zip(RHS_SYMBOL,RHS_SUFFIX,RHS_EMPTY))
7+
module stdlib_linalg_solve
8+
use stdlib_linalg_constants
9+
use stdlib_linalg_lapack, only: gesv
10+
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
11+
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
12+
implicit none(type,external)
13+
private
14+
15+
!> Solve a linear system
16+
public :: solve
17+
18+
! NumPy: solve(a, b)
19+
! Scipy: solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, check_finite=True, assume_a='gen', transposed=False)[source]#
20+
! IMSL: lu_solve(a, b, transpose=False)
21+
22+
interface solve
23+
#:for nd,ndsuf,nde in ALL_RHS
24+
#:for rk,rt,ri in RC_KINDS_TYPES
25+
module procedure stdlib_linalg_${ri}$solve${ndsuf}$
26+
#:endfor
27+
#:endfor
28+
end interface solve
29+
30+
31+
contains
32+
33+
#:for nd,ndsuf,nde in ALL_RHS
34+
#:for rk,rt,ri in RC_KINDS_TYPES
35+
! Compute the solution to a real system of linear equations A * X = B
36+
function stdlib_linalg_${ri}$solve${ndsuf}$(a,b,overwrite_a,err) result(x)
37+
!> Input matrix a[n,n]
38+
${rt}$, intent(inout), target :: a(:,:)
39+
!> Right hand side vector or array, b[n] or b[n,nrhs]
40+
${rt}$, intent(in) :: b${nd}$
41+
!> [optional] Can A data be overwritten and destroyed?
42+
logical(lk), optional, intent(in) :: overwrite_a
43+
!> [optional] state return flag. On error if not requested, the code will stop
44+
type(linalg_state_type), optional, intent(out) :: err
45+
!> Result array/matrix x[n] or x[n,nrhs]
46+
${rt}$, allocatable, target :: x${nd}$
47+
48+
!> Local variables
49+
type(linalg_state_type) :: err0
50+
integer(ilp) :: lda,n,ldb,nrhs,info
51+
integer(ilp), allocatable :: ipiv(:)
52+
logical(lk) :: copy_a
53+
${rt}$, pointer :: xmat(:,:),amat(:,:)
54+
character(*), parameter :: this = 'solve'
55+
56+
!> Problem sizes
57+
lda = size(a,1,kind=ilp)
58+
n = size(a,2,kind=ilp)
59+
ldb = size(b,1,kind=ilp)
60+
nrhs = size(b ,kind=ilp)/ldb
61+
62+
if (lda<1 .or. n<1 .or. ldb<1 .or. lda/=n .or. ldb/=n) then
63+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=[',lda,',',n,'],',&
64+
'b=[',ldb,',',nrhs,']')
65+
allocate(x${nde}$)
66+
goto 1
67+
end if
68+
69+
! Can A be overwritten? By default, do not overwrite
70+
if (present(overwrite_a)) then
71+
copy_a = .not.overwrite_a
72+
else
73+
copy_a = .true._lk
74+
endif
75+
76+
! Pivot indices
77+
allocate(ipiv(n))
78+
79+
! Initialize a matrix temporary
80+
if (copy_a) then
81+
allocate(amat(lda,n),source=a)
82+
else
83+
amat => a
84+
endif
85+
86+
! Initialize solution with the rhs
87+
allocate(x,source=b)
88+
xmat(1:n,1:nrhs) => x
89+
90+
! Solve system
91+
call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)
92+
93+
! Process output
94+
select case (info)
95+
case (0)
96+
! Success
97+
case (-1)
98+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size n=',n)
99+
case (-2)
100+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid rhs size n=',nrhs)
101+
case (-4)
102+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid matrix size a=[',lda,',',n,']')
103+
case (-7)
104+
err0 = linalg_state_type(this,LINALG_ERROR,'invalid matrix size a=[',lda,',',n,']')
105+
case (1:)
106+
err0 = linalg_state_type(this,LINALG_ERROR,'singular matrix')
107+
case default
108+
err0 = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
109+
end select
110+
111+
if (.not.copy_a) deallocate(amat)
112+
113+
! Process output and return
114+
1 call linalg_error_handling(err0,err)
115+
116+
end function stdlib_linalg_${ri}$solve${ndsuf}$
117+
118+
119+
#:endfor
120+
#:endfor
121+
122+
end module stdlib_linalg_solve

0 commit comments

Comments
 (0)