Skip to content

Commit df6483a

Browse files
committed
Base type and constructor interfaces for symtridiagonal matrices.
1 parent 3385843 commit df6483a

File tree

2 files changed

+372
-0
lines changed

2 files changed

+372
-0
lines changed

src/stdlib_specialmatrices.fypp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ module stdlib_specialmatrices
3535
end type
3636
#:endfor
3737

38+
!--> Symmetric Tridiagonal matrices
39+
#:for k1, t1, s1 in (KINDS_TYPES)
40+
type, public :: symtridiagonal_${s1}$_type
41+
!! Base type to define a `symtridiagonal` matrix.
42+
private
43+
${t1}$, allocatable :: dv(:), ev(:)
44+
integer(ilp) :: n
45+
end type
46+
#:endfor
47+
3848
!--------------------------------
3949
!----- -----
4050
!----- CONSTRUCTORS -----
@@ -126,6 +136,91 @@ module stdlib_specialmatrices
126136
#:endfor
127137
end interface
128138

139+
interface symtridiagonal
140+
!! ([Specifications](../page/specs/stdlib_specialmatrices.html#SymTridiagonal)) This
141+
!! interface provides different methods to construct a `sylmtridiagonal`
142+
!! matrix. Only the non-zero elements of \( A \) are stored, i.e.
143+
!!
144+
!! \[
145+
!! A
146+
!! =
147+
!! \begin{bmatrix}
148+
!! a_1 & b_1 \\
149+
!! b_1 & a_2 & b_2 \\
150+
!! & \ddots & \ddots & \ddots \\
151+
!! & & b_{n-2} & a_{n-1} & b_{n-1} \\
152+
!! & & & b_{n-1} & a_n
153+
!! \end{bmatrix}.
154+
!! \]
155+
!!
156+
!! #### Syntax
157+
!!
158+
!! - Construct a real `symtridiagonal` matrix from rank-1 arrays:
159+
!!
160+
!! ```fortran
161+
!! integer, parameter :: n
162+
!! real(dp), allocatable :: dv(:), ev(:)
163+
!! type(symtridiagonal_rdp_type) :: A
164+
!! integer :: i
165+
!!
166+
!! ev = [(i, i=1, n-1)]; dv = [(2*i, i=1, n)]
167+
!! A = SymTridiagonal(dv, ev)
168+
!! ```
169+
!!
170+
!! - Construct a real `symtridiagonal` matrix with constant diagonals:
171+
!!
172+
!! ```fortran
173+
!! integer, parameter :: n
174+
!! real(dp), parameter :: a = 1.0_dp, b = 1.0_dp
175+
!! type(symtridiagonal_rdp_type) :: A
176+
!!
177+
!! A = SymTridiagonal(a, b, n)
178+
!! ```
179+
#:for k1, t1, s1 in (KINDS_TYPES)
180+
pure module function initialize_symtridiagonal_pure_${s1}$(dv, ev) result(A)
181+
!! Construct a `tridiagonal` matrix from the rank-1 arrays
182+
!! `dl`, `dv` and `du`.
183+
${t1}$, intent(in) :: dv(:), ev(:)
184+
!! SymTridiagonal matrix elements.
185+
type(symtridiagonal_${s1}$_type) :: A
186+
!! Corresponding SymTridiagonal matrix.
187+
end function
188+
189+
pure module function initialize_constant_symtridiagonal_pure_${s1}$(dv, ev, n) result(A)
190+
!! Construct a `symtridiagonal` matrix with constant elements.
191+
${t1}$, intent(in) :: dv, ev
192+
!! SymTridiagonal matrix elements.
193+
integer(ilp), intent(in) :: n
194+
!! Matrix dimension.
195+
type(symtridiagonal_${s1}$_type) :: A
196+
!! Corresponding SymTridiagonal matrix.
197+
end function
198+
199+
module function initialize_symtridiagonal_impure_${s1}$(dv, ev, err) result(A)
200+
!! Construct a `symtridiagonal` matrix from the rank-1 arrays
201+
!! `dl`, `dv` and `du`.
202+
${t1}$, intent(in) :: dv(:), ev(:)
203+
!! Tridiagonal matrix elements.
204+
type(linalg_state_type), intent(out) :: err
205+
!! Error handling.
206+
type(symtridiagonal_${s1}$_type) :: A
207+
!! Corresponding SymTridiagonal matrix.
208+
end function
209+
210+
module function initialize_constant_symtridiagonal_impure_${s1}$(dv, ev, n, err) result(A)
211+
!! Construct a `symtridiagonal` matrix with constant elements.
212+
${t1}$, intent(in) :: dv, ev
213+
!! Tridiagonal matrix elements.
214+
integer(ilp), intent(in) :: n
215+
!! Matrix dimension.
216+
type(linalg_state_type), intent(out) :: err
217+
!! Error handling.
218+
type(symtridiagonal_${s1}$_type) :: A
219+
!! Corresponding SymTridiagonal matrix.
220+
end function
221+
#:endfor
222+
end interface
223+
129224
!----------------------------------
130225
!----- -----
131226
!----- LINEAR ALGEBRA -----
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
#:include "common.fypp"
2+
#:set RANKS = range(1, 2+1)
3+
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
4+
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
5+
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
6+
submodule (stdlib_specialmatrices) tridiagonal_matrices
7+
use stdlib_linalg_lapack, only: lagtm
8+
9+
character(len=*), parameter :: this = "tridiagonal matrices"
10+
contains
11+
12+
!--------------------------------
13+
!----- -----
14+
!----- CONSTRUCTORS -----
15+
!----- -----
16+
!--------------------------------
17+
18+
#:for k1, t1, s1 in (KINDS_TYPES)
19+
pure module function initialize_tridiagonal_pure_${s1}$(dl, dv, du) result(A)
20+
!! Construct a `tridiagonal` matrix from the rank-1 arrays
21+
!! `dl`, `dv` and `du`.
22+
${t1}$, intent(in) :: dl(:), dv(:), du(:)
23+
!! tridiagonal matrix elements.
24+
type(tridiagonal_${s1}$_type) :: A
25+
!! Corresponding tridiagonal matrix.
26+
27+
! Internal variables.
28+
integer(ilp) :: n
29+
type(linalg_state_type) :: err0
30+
31+
! Sanity check.
32+
n = size(dv, kind=ilp)
33+
if (n <= 0) then
34+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
35+
call linalg_error_handling(err0)
36+
endif
37+
if (size(dl, kind=ilp) /= n-1) then
38+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
39+
call linalg_error_handling(err0)
40+
endif
41+
if (size(du, kind=ilp) /= n-1) then
42+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
43+
call linalg_error_handling(err0)
44+
endif
45+
46+
! Description of the matrix.
47+
A%n = n
48+
! Matrix elements.
49+
A%dl = dl ; A%dv = dv ; A%du = du
50+
end function
51+
52+
pure module function initialize_constant_tridiagonal_pure_${s1}$(dl, dv, du, n) result(A)
53+
!! Construct a `tridiagonal` matrix with constant elements.
54+
${t1}$, intent(in) :: dl, dv, du
55+
!! tridiagonal matrix elements.
56+
integer(ilp), intent(in) :: n
57+
!! Matrix dimension.
58+
type(tridiagonal_${s1}$_type) :: A
59+
!! Corresponding tridiagonal matrix.
60+
61+
! Internal variables.
62+
integer(ilp) :: i
63+
type(linalg_state_type) :: err0
64+
65+
! Description of the matrix.
66+
A%n = n
67+
if (n <= 0) then
68+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
69+
call linalg_error_handling(err0)
70+
endif
71+
! Matrix elements.
72+
A%dl = [(dl, i = 1, n-1)]
73+
A%dv = [(dv, i = 1, n)]
74+
A%du = [(du, i = 1, n-1)]
75+
end function
76+
77+
module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
78+
!! Construct a `tridiagonal` matrix from the rank-1 arrays
79+
!! `dl`, `dv` and `du`.
80+
${t1}$, intent(in) :: dl(:), dv(:), du(:)
81+
!! tridiagonal matrix elements.
82+
type(linalg_state_type), intent(out) :: err
83+
!! Error handling.
84+
type(tridiagonal_${s1}$_type) :: A
85+
!! Corresponding tridiagonal matrix.
86+
87+
! Internal variables.
88+
integer(ilp) :: n
89+
type(linalg_state_type) :: err0
90+
91+
! Sanity check.
92+
n = size(dv, kind=ilp)
93+
if (n <= 0) then
94+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
95+
call linalg_error_handling(err0, err)
96+
endif
97+
if (size(dl, kind=ilp) /= n-1) then
98+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector dl does not have the correct length.")
99+
call linalg_error_handling(err0, err)
100+
endif
101+
if (size(du, kind=ilp) /= n-1) then
102+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Vector du does not have the correct length.")
103+
call linalg_error_handling(err0, err)
104+
endif
105+
106+
! Description of the matrix.
107+
A%n = n
108+
! Matrix elements.
109+
A%dl = dl ; A%dv = dv ; A%du = du
110+
end function
111+
112+
module function initialize_constant_tridiagonal_impure_${s1}$(dl, dv, du, n, err) result(A)
113+
!! Construct a `tridiagonal` matrix with constant elements.
114+
${t1}$, intent(in) :: dl, dv, du
115+
!! tridiagonal matrix elements.
116+
integer(ilp), intent(in) :: n
117+
!! Matrix dimension.
118+
type(linalg_state_type), intent(out) :: err
119+
!! Error handling
120+
type(tridiagonal_${s1}$_type) :: A
121+
!! Corresponding tridiagonal matrix.
122+
123+
! Internal variables.
124+
integer(ilp) :: i
125+
type(linalg_state_type) :: err0
126+
127+
! Description of the matrix.
128+
A%n = n
129+
if (n <= 0) then
130+
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
131+
call linalg_error_handling(err0, err)
132+
endif
133+
! Matrix elements.
134+
A%dl = [(dl, i = 1, n-1)]
135+
A%dv = [(dv, i = 1, n)]
136+
A%du = [(du, i = 1, n-1)]
137+
end function
138+
#:endfor
139+
140+
!-----------------------------------------
141+
!----- -----
142+
!----- MATRIX-VECTOR PRODUCT -----
143+
!----- -----
144+
!-----------------------------------------
145+
146+
!! spmv_tridiag
147+
#:for k1, t1, s1 in (KINDS_TYPES)
148+
#:for rank in RANKS
149+
module subroutine spmv_tridiag_${rank}$d_${s1}$(A, x, y, alpha, beta, op)
150+
type(tridiagonal_${s1}$_type), intent(in) :: A
151+
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
152+
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
153+
real(${k1}$), intent(in), optional :: alpha
154+
real(${k1}$), intent(in), optional :: beta
155+
character(1), intent(in), optional :: op
156+
157+
! Internal variables.
158+
real(${k1}$) :: alpha_, beta_
159+
integer(ilp) :: n, nrhs, ldx, ldy
160+
character(1) :: op_
161+
#:if rank == 1
162+
${t1}$, pointer :: xmat(:, :), ymat(:, :)
163+
#:endif
164+
165+
! Deal with optional arguments.
166+
alpha_ = 1.0_${k1}$ ; if (present(alpha)) alpha_ = alpha
167+
beta_ = 0.0_${k1}$ ; if (present(beta)) beta_ = beta
168+
op_ = "N" ; if (present(op)) op_ = op
169+
170+
! Prepare Lapack arguments.
171+
n = A%n ; ldx = n ; ldy = n ; y = 0.0_${k1}$
172+
nrhs = #{if rank==1}# 1 #{else}# size(x, dim=2, kind=ilp) #{endif}#
173+
174+
#:if rank == 1
175+
! Pointer trick.
176+
xmat(1:n, 1:nrhs) => x ; ymat(1:n, 1:nrhs) => y
177+
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
178+
#:else
179+
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
180+
#:endif
181+
end subroutine
182+
#:endfor
183+
#:endfor
184+
185+
!-------------------------------------
186+
!----- -----
187+
!----- UTILITY FUNCTIONS -----
188+
!----- -----
189+
!-------------------------------------
190+
191+
#:for k1, t1, s1 in (KINDS_TYPES)
192+
pure module function tridiagonal_to_dense_${s1}$(A) result(B)
193+
!! Convert a `tridiagonal` matrix to its dense representation.
194+
type(tridiagonal_${s1}$_type), intent(in) :: A
195+
!! Input tridiagonal matrix.
196+
${t1}$, allocatable :: B(:, :)
197+
!! Corresponding dense matrix.
198+
199+
! Internal variables.
200+
integer(ilp) :: i
201+
202+
associate (n => A%n)
203+
#:if t1.startswith('complex')
204+
allocate(B(n, n), source=zero_c${k1}$)
205+
#:else
206+
allocate(B(n, n), source=zero_${k1}$)
207+
#:endif
208+
B(1, 1) = A%dv(1) ; B(1, 2) = A%du(1)
209+
do concurrent (i=2:n-1)
210+
B(i, i-1) = A%dl(i-1)
211+
B(i, i) = A%dv(i)
212+
B(i, i+1) = A%du(i)
213+
enddo
214+
B(n, n-1) = A%dl(n-1) ; B(n, n) = A%dv(n)
215+
end associate
216+
end function
217+
#:endfor
218+
219+
#:for k1, t1, s1 in (KINDS_TYPES)
220+
pure module function transpose_tridiagonal_${s1}$(A) result(B)
221+
type(tridiagonal_${s1}$_type), intent(in) :: A
222+
!! Input matrix.
223+
type(tridiagonal_${s1}$_type) :: B
224+
B = tridiagonal(A%du, A%dv, A%dl)
225+
end function
226+
#:endfor
227+
228+
#:for k1, t1, s1 in (KINDS_TYPES)
229+
pure module function hermitian_tridiagonal_${s1}$(A) result(B)
230+
type(tridiagonal_${s1}$_type), intent(in) :: A
231+
!! Input matrix.
232+
type(tridiagonal_${s1}$_type) :: B
233+
#:if t1.startswith("complex")
234+
B = tridiagonal(conjg(A%du), conjg(A%dv), conjg(A%dl))
235+
#:else
236+
B = tridiagonal(A%du, A%dv, A%dl)
237+
#:endif
238+
end function
239+
#:endfor
240+
241+
#:for k1, t1, s1 in (KINDS_TYPES)
242+
pure module function scalar_multiplication_tridiagonal_${s1}$(alpha, A) result(B)
243+
${t1}$, intent(in) :: alpha
244+
type(tridiagonal_${s1}$_type), intent(in) :: A
245+
type(tridiagonal_${s1}$_type) :: B
246+
B = tridiagonal(A%dl, A%dv, A%du)
247+
B%dl = alpha*B%dl; B%dv = alpha*B%dv; B%du = alpha*B%du
248+
end function
249+
250+
pure module function scalar_multiplication_bis_tridiagonal_${s1}$(A, alpha) result(B)
251+
type(tridiagonal_${s1}$_type), intent(in) :: A
252+
${t1}$, intent(in) :: alpha
253+
type(tridiagonal_${s1}$_type) :: B
254+
B = tridiagonal(A%dl, A%dv, A%du)
255+
B%dl = alpha*B%dl; B%dv = alpha*B%dv; B%du = alpha*B%du
256+
end function
257+
#:endfor
258+
259+
#:for k1, t1, s1 in (KINDS_TYPES)
260+
pure module function matrix_add_tridiagonal_${s1}$(A, B) result(C)
261+
type(tridiagonal_${s1}$_type), intent(in) :: A
262+
type(tridiagonal_${s1}$_type), intent(in) :: B
263+
type(tridiagonal_${s1}$_type) :: C
264+
C = tridiagonal(A%dl, A%dv, A%du)
265+
C%dl = C%dl + B%dl; C%dv = C%dv + B%dv; C%du = C%du + B%du
266+
end function
267+
268+
pure module function matrix_sub_tridiagonal_${s1}$(A, B) result(C)
269+
type(tridiagonal_${s1}$_type), intent(in) :: A
270+
type(tridiagonal_${s1}$_type), intent(in) :: B
271+
type(tridiagonal_${s1}$_type) :: C
272+
C = tridiagonal(A%dl, A%dv, A%du)
273+
C%dl = C%dl - B%dl; C%dv = C%dv - B%dv; C%du = C%du - B%du
274+
end function
275+
#:endfor
276+
277+
end submodule

0 commit comments

Comments
 (0)