Skip to content

Commit fc54644

Browse files
committed
Added test for constrained lstsq.
1 parent 36c17fb commit fc54644

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

test/linalg/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(
66
"test_linalg_inverse.fypp"
77
"test_linalg_pseudoinverse.fypp"
88
"test_linalg_lstsq.fypp"
9+
"test_linalg_constrained_lstsq.fypp"
910
"test_linalg_norm.fypp"
1011
"test_linalg_mnorm.fypp"
1112
"test_linalg_determinant.fypp"
@@ -41,6 +42,7 @@ ADDTEST(linalg_norm)
4142
ADDTEST(linalg_mnorm)
4243
ADDTEST(linalg_solve)
4344
ADDTEST(linalg_lstsq)
45+
ADDTEST(linalg_constrained_lstsq)
4446
ADDTEST(linalg_qr)
4547
ADDTEST(linalg_schur)
4648
ADDTEST(linalg_solve_iterative)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
#:include "common.fypp"
2+
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
3+
! Test least squares solver
4+
module test_linalg_constrained_least_squares
5+
use testdrive, only: error_type, check, new_unittest, unittest_type
6+
use stdlib_linalg_constants
7+
use stdlib_linalg, only: constrained_lstsq, solve_constrained_lstsq, constrained_lstsq_space
8+
use stdlib_linalg_state, only: linalg_state_type
9+
10+
implicit none (type,external)
11+
private
12+
13+
public :: test_constrained_least_squares
14+
15+
contains
16+
17+
!> Solve sample least squares problems
18+
subroutine test_constrained_least_squares(tests)
19+
!> Collection of tests
20+
type(unittest_type), allocatable, intent(out) :: tests(:)
21+
22+
allocate(tests(0))
23+
24+
#:for rk,rt,ri in REAL_KINDS_TYPES
25+
call add_test(tests,new_unittest("constrained_least_squares_randm_${ri}$",test_constrained_lstsq_random_${ri}$))
26+
#:endfor
27+
28+
end subroutine test_constrained_least_squares
29+
30+
#:for rk,rt,ri in REAL_KINDS_TYPES
31+
!> Fit from random array
32+
subroutine test_constrained_lstsq_random_${ri}$(error)
33+
type(error_type), allocatable, intent(out) :: error
34+
type(linalg_state_type) :: state
35+
integer(ilp), parameter :: m=5, n=4, p=3
36+
!> Least-squares cost.
37+
${rt}$ :: A(m, n), b(m)
38+
!> Equality constraints.
39+
${rt}$ :: C(p, n), d(p)
40+
!> Solution.
41+
${rt}$ :: x(n), x_true(n)
42+
!> Workspace.
43+
integer(ilp) :: lwork
44+
${rt}$, allocatable :: work(:)
45+
46+
!> Least-squares cost.
47+
A(1, :) = [1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$]
48+
A(2, :) = [1.0_${rk}$, 3.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$]
49+
A(3, :) = [1.0_${rk}$, -1.0_${rk}$, 3.0_${rk}$, 1.0_${rk}$]
50+
A(4, :) = [1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$, 3.0_${rk}$]
51+
A(5, :) = [1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$, -1.0_${rk}$]
52+
53+
b = [2.0_${rk}$, 1.0_${rk}$, 6.0_${rk}$, 3.0_${rk}$, 1.0_${rk}$]
54+
55+
!> Equality constraints.
56+
C(1, :) = [1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$, -1.0_${rk}$]
57+
C(2, :) = [1.0_${rk}$, -1.0_${rk}$, 1.0_${rk}$, 1.0_${rk}$]
58+
C(3, :) = [1.0_${rk}$, 1.0_${rk}$, -1.0_${rk}$, 1.0_${rk}$]
59+
60+
d = [1.0_${rk}$, 3.0_${rk}$, -1.0_${rk}$]
61+
62+
!----- Function interface -----
63+
x = constrained_lstsq(A, b, C, d, err=state)
64+
x_true = [0.5_${rk}$, -0.5_${rk}$, 1.5_${rk}$, 0.5_${rk}$]
65+
66+
call check(error, state%ok(), state%print())
67+
if (allocated(error)) return
68+
69+
call check(error, all(abs(x-x_true) < 1.0e-4_${rk}$), 'Solver converged')
70+
if (allocated(error)) return
71+
72+
!----- Subroutine interface -----
73+
call solve_constrained_lstsq(A, b, C, d, x, err=state)
74+
75+
call check(error, state%ok(), state%print())
76+
if (allocated(error)) return
77+
78+
call check(error, all(abs(x-x_true) < 1.0e-4_${rk}$), 'Solver converged')
79+
if (allocated(error)) return
80+
81+
!----- Pre-allocated storage -----
82+
call constrained_lstsq_space(A, b, C, d, lwork, err=state)
83+
call check(error, state%ok(), state%print())
84+
if (allocated(error)) return
85+
allocate(work(lwork))
86+
call solve_constrained_lstsq(A, b, C, d, x, storage=work, err=state)
87+
88+
call check(error, state%ok(), state%print())
89+
if (allocated(error)) return
90+
91+
call check(error, all(abs(x-x_true) < 1.0e-4_${rk}$), 'Solver converged')
92+
if (allocated(error)) return
93+
94+
!----- Overwrite matrices (performances) -----
95+
call solve_constrained_lstsq(A, b, C, d, x, storage=work, overwrite_matrices=.true., err=state)
96+
97+
call check(error, state%ok(), state%print())
98+
if (allocated(error)) return
99+
100+
call check(error, all(abs(x-x_true) < 1.0e-4_${rk}$), 'Solver converged')
101+
if (allocated(error)) return
102+
103+
end subroutine test_constrained_lstsq_random_${ri}$
104+
105+
#:endfor
106+
107+
! gcc-15 bugfix utility
108+
subroutine add_test(tests,new_test)
109+
type(unittest_type), allocatable, intent(inout) :: tests(:)
110+
type(unittest_type), intent(in) :: new_test
111+
112+
integer :: n
113+
type(unittest_type), allocatable :: new_tests(:)
114+
115+
if (allocated(tests)) then
116+
n = size(tests)
117+
else
118+
n = 0
119+
end if
120+
121+
allocate(new_tests(n+1))
122+
if (n>0) new_tests(1:n) = tests(1:n)
123+
new_tests(1+n) = new_test
124+
call move_alloc(from=new_tests,to=tests)
125+
126+
end subroutine add_test
127+
128+
end module test_linalg_constrained_least_squares
129+
130+
program test_constrained_lstsq
131+
use, intrinsic :: iso_fortran_env, only : error_unit
132+
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
133+
use test_linalg_constrained_least_squares, only : test_constrained_least_squares
134+
implicit none
135+
integer :: stat, is
136+
type(testsuite_type), allocatable :: testsuites(:)
137+
character(len=*), parameter :: fmt = '("#", *(1x, a))'
138+
139+
stat = 0
140+
141+
testsuites = [ &
142+
new_testsuite("linalg_constrained_least_squares", test_constrained_least_squares) &
143+
]
144+
145+
do is = 1, size(testsuites)
146+
write(error_unit, fmt) "Testing:", testsuites(is)%name
147+
call run_testsuite(testsuites(is)%collect, error_unit, stat)
148+
end do
149+
150+
if (stat > 0) then
151+
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
152+
error stop
153+
end if
154+
end program test_constrained_lstsq

0 commit comments

Comments
 (0)