Skip to content

Commit dd9b022

Browse files
committed
Added test for pivoting QR.
1 parent 55bc413 commit dd9b022

File tree

2 files changed

+247
-0
lines changed

2 files changed

+247
-0
lines changed

test/linalg/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ set(
1010
"test_linalg_mnorm.fypp"
1111
"test_linalg_determinant.fypp"
1212
"test_linalg_qr.fypp"
13+
"test_linalg_pivoting_qr.fypp"
1314
"test_linalg_schur.fypp"
1415
"test_linalg_solve_iterative.fypp"
1516
"test_linalg_svd.fypp"
@@ -42,6 +43,7 @@ ADDTEST(linalg_mnorm)
4243
ADDTEST(linalg_solve)
4344
ADDTEST(linalg_lstsq)
4445
ADDTEST(linalg_qr)
46+
ADDTEST(linalg_pivoting_qr)
4547
ADDTEST(linalg_schur)
4648
ADDTEST(linalg_solve_iterative)
4749
ADDTEST(linalg_svd)
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#:include "common.fypp"
2+
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
3+
! Test QR factorization
4+
module test_linalg_pivoting_qr
5+
use testdrive, only: error_type, check, new_unittest, unittest_type
6+
use stdlib_linalg_constants
7+
use stdlib_linalg_state, only: LINALG_VALUE_ERROR,linalg_state_type
8+
use stdlib_linalg, only: qr,qr_space
9+
use ieee_arithmetic, only: ieee_value,ieee_quiet_nan
10+
11+
implicit none (type,external)
12+
13+
public :: test_pivoting_qr_factorization
14+
15+
contains
16+
17+
!> QR factorization tests
18+
subroutine test_pivoting_qr_factorization(tests)
19+
!> Collection of tests
20+
type(unittest_type), allocatable, intent(out) :: tests(:)
21+
22+
allocate(tests(0))
23+
24+
#:for rk,rt,ri in RC_KINDS_TYPES
25+
call add_test(tests,new_unittest("pivoting_qr_random_${ri}$",test_pivoting_qr_random_${ri}$))
26+
#:endfor
27+
28+
end subroutine test_pivoting_qr_factorization
29+
30+
!> QR factorization of a random matrix
31+
#:for rk,rt,ri in RC_KINDS_TYPES
32+
subroutine test_pivoting_qr_random_${ri}$(error)
33+
type(error_type), allocatable, intent(out) :: error
34+
35+
!-------------------------------
36+
!----- Tall matrix -----
37+
!-------------------------------
38+
block
39+
integer(ilp), parameter :: m = 15_ilp
40+
integer(ilp), parameter :: n = 4_ilp
41+
integer(ilp), parameter :: k = min(m,n)
42+
real(${rk}$), parameter :: tol = 100*sqrt(epsilon(0.0_${rk}$))
43+
${rt}$ :: a(m,n),aorig(m,n),q(m,m),r(m,n),qred(m,k),rred(k,n),qerr(m,6),rerr(6,n)
44+
real(${rk}$) :: rea(m,n),ima(m,n)
45+
integer(ilp) :: pivots(n), i, j
46+
integer(ilp) :: lwork
47+
${rt}$, allocatable :: work(:)
48+
type(linalg_state_type) :: state
49+
50+
call random_number(rea)
51+
#:if rt.startswith('complex')
52+
call random_number(ima)
53+
a = cmplx(rea,ima,kind=${rk}$)
54+
#:else
55+
a = rea
56+
#:endif
57+
aorig = a
58+
59+
! 1) QR factorization with full matrices. Input NaNs to be sure Q and R are OK on return
60+
q = ieee_value(0.0_${rk}$,ieee_quiet_nan)
61+
r = ieee_value(0.0_${rk}$,ieee_quiet_nan)
62+
call qr(a, q, r, pivots, err=state)
63+
64+
! Check return code
65+
call check(error,state%ok(),state%print())
66+
if (allocated(error)) return
67+
68+
! Check solution
69+
call check(error, all(abs(a(:, pivots)-matmul(q,r))<tol), 'converged solution (fulle)')
70+
if (allocated(error)) return
71+
72+
! 2) QR factorization with reduced matrices
73+
call qr(a, qred, rred, pivots, err=state)
74+
75+
! Check return code
76+
call check(error,state%ok(),state%print())
77+
if (allocated(error)) return
78+
79+
! Check solution
80+
call check(error, all(abs(a(:, pivots)-matmul(qred,rred))<tol), 'converged solution (reduced)')
81+
if (allocated(error)) return
82+
83+
! 3) overwrite A
84+
call qr(a, qred, rred, pivots, overwrite_a=.true., err=state)
85+
86+
! Check return code
87+
call check(error,state%ok(),state%print())
88+
if (allocated(error)) return
89+
90+
! Check solution
91+
call check(error, all(abs(aorig(:, pivots)-matmul(qred,rred))<tol), 'converged solution (overwrite A)')
92+
if (allocated(error)) return
93+
94+
! 4) External storage option
95+
a = aorig
96+
call qr_space(a, lwork, pivoting=.true.)
97+
allocate(work(lwork))
98+
call qr(a, q, r, pivots, storage=work, err=state)
99+
100+
! Check return code
101+
call check(error,state%ok(),state%print())
102+
if (allocated(error)) return
103+
104+
! Check solution
105+
call check(error, all(abs(a(:, pivots)-matmul(q,r))<tol), 'converged solution (external storage)')
106+
if (allocated(error)) return
107+
108+
! Check that an invalid problem size returns an error
109+
a = aorig
110+
call qr(a, qerr, rerr, pivots, err=state)
111+
call check(error,state%error(),'invalid matrix sizes')
112+
if (allocated(error)) return
113+
end block
114+
115+
!-------------------------------
116+
!----- Wide matrix -----
117+
!-------------------------------
118+
block
119+
integer(ilp), parameter :: m = 4_ilp
120+
integer(ilp), parameter :: n = 15_ilp
121+
integer(ilp), parameter :: k = min(m,n)
122+
real(${rk}$), parameter :: tol = 100*sqrt(epsilon(0.0_${rk}$))
123+
${rt}$ :: a(m,n),aorig(m,n),q(m,m),r(m,n),qred(m,k),rred(k,n),qerr(m,6),rerr(6,n)
124+
real(${rk}$) :: rea(m,n),ima(m,n)
125+
integer(ilp) :: pivots(n), i, j
126+
integer(ilp) :: lwork
127+
${rt}$, allocatable :: work(:)
128+
type(linalg_state_type) :: state
129+
130+
call random_number(rea)
131+
#:if rt.startswith('complex')
132+
call random_number(ima)
133+
a = cmplx(rea,ima,kind=${rk}$)
134+
#:else
135+
a = rea
136+
#:endif
137+
aorig = a
138+
139+
! 1) QR factorization with full matrices. Input NaNs to be sure Q and R are OK on return
140+
q = ieee_value(0.0_${rk}$,ieee_quiet_nan)
141+
r = ieee_value(0.0_${rk}$,ieee_quiet_nan)
142+
call qr(a, q, r, pivots, err=state)
143+
144+
! Check return code
145+
call check(error,state%ok(),state%print())
146+
if (allocated(error)) return
147+
148+
! Check solution
149+
call check(error, all(abs(a(:, pivots)-matmul(q,r))<tol), 'converged solution (fulle)')
150+
if (allocated(error)) return
151+
152+
! 2) QR factorization with reduced matrices
153+
call qr(a, qred, rred, pivots, err=state)
154+
155+
! Check return code
156+
call check(error,state%ok(),state%print())
157+
if (allocated(error)) return
158+
159+
! Check solution
160+
call check(error, all(abs(a(:, pivots)-matmul(qred,rred))<tol), 'converged solution (reduced)')
161+
if (allocated(error)) return
162+
163+
! 3) overwrite A
164+
call qr(a, qred, rred, pivots, overwrite_a=.true., err=state)
165+
166+
! Check return code
167+
call check(error,state%ok(),state%print())
168+
if (allocated(error)) return
169+
170+
! Check solution
171+
call check(error, all(abs(aorig(:, pivots)-matmul(qred,rred))<tol), 'converged solution (overwrite A)')
172+
if (allocated(error)) return
173+
174+
! 4) External storage option
175+
a = aorig
176+
call qr_space(a, lwork, pivoting=.true.)
177+
allocate(work(lwork))
178+
call qr(a, q, r, pivots, storage=work, err=state)
179+
180+
! Check return code
181+
call check(error,state%ok(),state%print())
182+
if (allocated(error)) return
183+
184+
! Check solution
185+
call check(error, all(abs(a(:, pivots)-matmul(q,r))<tol), 'converged solution (external storage)')
186+
if (allocated(error)) return
187+
188+
! Check that an invalid problem size returns an error
189+
a = aorig
190+
call qr(a, qerr, rerr, pivots, err=state)
191+
call check(error,state%error(),'invalid matrix sizes')
192+
if (allocated(error)) return
193+
end block
194+
end subroutine test_pivoting_qr_random_${ri}$
195+
196+
#:endfor
197+
198+
! gcc-15 bugfix utility
199+
subroutine add_test(tests,new_test)
200+
type(unittest_type), allocatable, intent(inout) :: tests(:)
201+
type(unittest_type), intent(in) :: new_test
202+
203+
integer :: n
204+
type(unittest_type), allocatable :: new_tests(:)
205+
206+
if (allocated(tests)) then
207+
n = size(tests)
208+
else
209+
n = 0
210+
end if
211+
212+
allocate(new_tests(n+1))
213+
if (n>0) new_tests(1:n) = tests(1:n)
214+
new_tests(1+n) = new_test
215+
call move_alloc(from=new_tests,to=tests)
216+
217+
end subroutine add_test
218+
219+
end module test_linalg_pivoting_qr
220+
221+
program test_pivoting_qr
222+
use, intrinsic :: iso_fortran_env, only : error_unit
223+
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
224+
use test_linalg_pivoting_qr, only : test_pivoting_qr_factorization
225+
implicit none
226+
integer :: stat, is
227+
type(testsuite_type), allocatable :: testsuites(:)
228+
character(len=*), parameter :: fmt = '("#", *(1x, a))'
229+
230+
stat = 0
231+
232+
testsuites = [ &
233+
new_testsuite("linalg_pivoting_qr", test_pivoting_qr_factorization) &
234+
]
235+
236+
do is = 1, size(testsuites)
237+
write(error_unit, fmt) "Testing:", testsuites(is)%name
238+
call run_testsuite(testsuites(is)%collect, error_unit, stat)
239+
end do
240+
241+
if (stat > 0) then
242+
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
243+
error stop
244+
end if
245+
end program test_pivoting_qr

0 commit comments

Comments
 (0)