Skip to content

Commit 728d221

Browse files
committed
Compilable implementation.
1 parent a8873d9 commit 728d221

File tree

2 files changed

+147
-5
lines changed

2 files changed

+147
-5
lines changed

src/stdlib_linalg.fypp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ module stdlib_linalg
3838
public :: operator(.pinv.)
3939
public :: lstsq
4040
public :: lstsq_space
41+
public :: constrained_lstsq
42+
public :: constrained_lstsq_space
4143
public :: norm
4244
public :: mnorm
4345
public :: get_norm
4446
public :: solve
4547
public :: solve_lu
4648
public :: solve_lstsq
49+
public :: solve_constrained_lstsq
4750
public :: trace
4851
public :: svd
4952
public :: svdvals
@@ -607,7 +610,7 @@ module stdlib_linalg
607610
!> Solution vector.
608611
${rt}$, intent(out) :: x(:)
609612
!> [optional] Storage.
610-
${rt}$, optional, intent(out) :: storage(:)
613+
${rt}$, optional, intent(out), target :: storage(:)
611614
!> [optional] Can A and C be overwritten?
612615
logical(lk), optional, intent(in) :: overwrite_matrices
613616
!> [optional] State return flag. On error if not requested, the code stops.
@@ -618,10 +621,11 @@ module stdlib_linalg
618621

619622
interface constrained_lstsq_space
620623
#:for rk, rt, ri in RC_KINDS_TYPES
621-
pure module subroutine stdlib_linalg_${ri}$_constrained_lstsq_space(A, b, C, d, lwork)
624+
module subroutine stdlib_linalg_${ri}$_constrained_lstsq_space(A, b, C, d, lwork, err)
622625
${rt}$, intent(in), target :: A(:, :), C(:, :)
623626
${rt}$, intent(in), target :: b(:), d(:)
624627
integer(ilp), intent(out) :: lwork
628+
type(linalg_state_type), optional, intent(out) :: err
625629
end subroutine stdlib_linalg_${ri}$_constrained_lstsq_space
626630
#:endfor
627631
end interface

src/stdlib_linalg_least_squares.fypp

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
submodule (stdlib_linalg) stdlib_linalg_least_squares
88
!! Least-squares solution to Ax=b
99
use stdlib_linalg_constants
10-
use stdlib_linalg_lapack, only: gelsd, stdlib_ilaenv
11-
use stdlib_linalg_lapack_aux, only: handle_gelsd_info
10+
use stdlib_linalg_lapack, only: gelsd, gglse, stdlib_ilaenv
11+
use stdlib_linalg_lapack_aux, only: handle_gelsd_info, handle_gglse_info
1212
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
1313
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
1414
implicit none
@@ -170,7 +170,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
170170
#:if rt.startswith('c')
171171
!> [optional] complex working storage space
172172
${rt}$, optional, intent(inout), target :: cmpl_storage(:)
173-
#:endif
173+
#:endif
174174
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
175175
real(${rk}$), optional, intent(in) :: cond
176176
!> [optional] list of singular values [min(m,n)], in descending magnitude order, returned by the SVD
@@ -363,4 +363,142 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
363363
endif
364364
end function ilog2
365365

366+
!-------------------------------------------------------------
367+
!----- Equality-constrained Least-Squares solver -----
368+
!-------------------------------------------------------------
369+
370+
pure subroutine check_problem_size(ma, na, mb, mc, nc, md, mx, err)
371+
integer(ilp), intent(in) :: ma, na, mb, mc, nc, md, mx
372+
type(linalg_state_type), intent(out) :: err
373+
374+
! Check sizes.
375+
if (ma < 1 .or. na < 1) then
376+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Invalid matrix size a(m, n) =', [ma, na])
377+
return
378+
else if (mc < 1 .or. nc < 1) then
379+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Invalid matrix size c(m, n) =', [mc, nc])
380+
else if (na /= nc) then
381+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Matrix A and matrix C have inconsistent number of columns.')
382+
else if (mb /= ma) then
383+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Size(b) inconsistent with number of rows in a, size(b) =', mb)
384+
else if (md /= mc) then
385+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Size(d) inconsistent with number of rows in c, size(d) =', md)
386+
else if (na /= mx) then
387+
err = linalg_state_type(this, LINALG_VALUE_ERROR, 'Size(x) inconsistent with number of columns of a, size(x) =', mx)
388+
endif
389+
end subroutine check_problem_size
390+
391+
#:for rk, rt, ri in RC_KINDS_TYPES
392+
module subroutine stdlib_linalg_${ri}$_constrained_lstsq_space(A, b, C, d, lwork, err)
393+
${rt}$, intent(in), target :: A(:, :), C(:, :)
394+
${rt}$, intent(in), target :: b(:), d(:)
395+
integer(ilp), intent(out) :: lwork
396+
type(linalg_state_type), optional, intent(out) :: err
397+
end subroutine stdlib_linalg_${ri}$_constrained_lstsq_space
398+
399+
module subroutine stdlib_linalg_${ri}$_solve_constrained_lstsq(A, b, C, d, x, storage, overwrite_matrices, err)
400+
!> Input matrices.
401+
${rt}$, intent(inout), target :: A(:, :), C(:, :)
402+
!> Right-hand side vectors.
403+
${rt}$, intent(inout), target :: b(:), d(:)
404+
!> Solution vector.
405+
${rt}$, intent(out) :: x(:)
406+
!> [optional] Storage.
407+
${rt}$, optional, intent(out), target :: storage(:)
408+
!> [optional] Can A, b, C, and d be overwritten?
409+
logical(lk), optional, intent(in) :: overwrite_matrices
410+
!> [optional] State return flag.
411+
type(linalg_state_type), optional, intent(out) :: err
412+
413+
! Local variables.
414+
type(linalg_state_type) :: err0
415+
integer(ilp) :: ma, na, mb
416+
integer(ilp) :: mc, nc, md
417+
integer(ilp) :: mx
418+
logical(lk) :: overwrite_matrices_
419+
${rt}$, pointer :: amat(:, :), bvec(:)
420+
${rt}$, pointer :: cmat(:, :), dvec(:)
421+
! LAPACK related.
422+
integer(ilp) :: lwork, info
423+
${rt}$, pointer :: work(:)
424+
425+
!> Check dimensions.
426+
ma = size(A, 1, kind=ilp) ; na = size(A, 2, kind=ilp)
427+
mc = size(C, 1, kind=ilp) ; nc = size(C, 2, kind=ilp)
428+
mb = size(b, kind=ilp) ; md = size(d, kind=ilp) ; mx = size(x, kind=ilp)
429+
call check_problem_size(ma, na, mb, mc, nc, md, mx, err0)
430+
if (err0%error()) then
431+
call linalg_error_handling(err0, err)
432+
return
433+
endif
434+
435+
!> Check if matrices can be overwritten.
436+
overwrite_matrices_ = optval(overwrite_matrices, .false._lk)
437+
438+
!> Allocate matrices.
439+
if (overwrite_matrices_) then
440+
amat => a
441+
bvec => b
442+
cmat => c
443+
dvec => d
444+
else
445+
allocate(amat(ma, na), source=a)
446+
allocate(bvec(mb), source=b)
447+
allocate(cmat(mc, nc), source=c)
448+
allocate(dvec(md), source=d)
449+
endif
450+
451+
!> Retrieve workspace size.
452+
call stdlib_linalg_${ri}$_constrained_lstsq_space(A, b, C, d, lwork, err0)
453+
454+
if (err0%ok()) then
455+
!> Workspace.
456+
if (present(storage)) then
457+
work => storage
458+
else
459+
allocate(work(lwork))
460+
endif
461+
if (size(work, kind=ilp) < lwork) then
462+
err0 = linalg_state_type(this, LINALG_ERROR, 'Insufficient workspace. Should be at least ', lwork)
463+
call linalg_error_handling(err0, err)
464+
return
465+
endif
466+
467+
!> Compute constrained lstsq solution.
468+
call gglse(ma, na, mc, amat, ma, cmat, mc, bvec, dvec, x, work, lwork, info)
469+
call handle_gglse_info(this, info, ma, na, mc, err0)
470+
471+
!> Deallocate.
472+
deallocate(work)
473+
endif
474+
475+
if (.not. overwrite_matrices_) then
476+
deallocate(amat, bvec, cmat, dvec)
477+
endif
478+
479+
call linalg_error_handling(err0, err)
480+
481+
end subroutine stdlib_linalg_${ri}$_solve_constrained_lstsq
482+
483+
module function stdlib_linalg_${ri}$_constrained_lstsq(A, b, C, d, overwrite_matrices, err) result(x)
484+
!> Input matrices.
485+
${rt}$, intent(inout), target :: A(:, :), C(:, :)
486+
!> Right-hand side vectors.
487+
${rt}$, intent(inout), target :: b(:), d(:)
488+
!> [optional] Can A, b, C, d be overwritten?
489+
logical(lk), optional, intent(in) :: overwrite_matrices
490+
!> [optional] State return flag.
491+
type(linalg_state_type), optional, intent(out) :: err
492+
!> Solution of the constrained least-squares problem.
493+
${rt}$, allocatable, target :: x(:)
494+
495+
! Local variables.
496+
integer(ilp) :: n
497+
498+
n = size(A, 2, kind=ilp)
499+
allocate(x(n))
500+
call stdlib_linalg_${ri}$_solve_constrained_lstsq(A, b, C, d, x, overwrite_matrices=overwrite_matrices, err=err)
501+
end function stdlib_linalg_${ri}$_constrained_lstsq
502+
#:endfor
503+
366504
end submodule stdlib_linalg_least_squares

0 commit comments

Comments
 (0)