Skip to content

Commit 55bc413

Browse files
committed
Full implementation of pivoting QR.
1 parent 19cc0f0 commit 55bc413

File tree

2 files changed

+210
-2
lines changed

2 files changed

+210
-2
lines changed

src/stdlib_linalg.fypp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,17 @@ module stdlib_linalg
658658
!> State return flag. Returns an error if the query failed
659659
type(linalg_state_type), optional, intent(out) :: err
660660
end subroutine get_qr_${ri}$_workspace
661+
662+
pure module subroutine get_pivoting_qr_${ri}$_workspace(a, lwork, pivoting, err)
663+
!> Input matrix a[m, n]
664+
${rt}$, intent(in), target :: a(:, :)
665+
!> Minimum workspace size for both operations.
666+
integer(ilp), intent(out) :: lwork
667+
!> Pivoting flag.
668+
logical(lk), intent(in) :: pivoting
669+
!> State return flag. Returns an error if the query failed.
670+
type(linalg_state_type), optional, intent(out) :: err
671+
end subroutine get_pivoting_qr_${ri}$_workspace
661672
#:endfor
662673
end interface qr_space
663674

src/stdlib_linalg_qr.fypp

Lines changed: 199 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
33
submodule (stdlib_linalg) stdlib_linalg_qr
44
use stdlib_linalg_constants
5-
use stdlib_linalg_lapack, only: geqrf, orgqr, ungqr
6-
use stdlib_linalg_lapack_aux, only: handle_geqrf_info, handle_orgqr_info
5+
use stdlib_linalg_lapack, only: geqrf, geqp3, orgqr, ungqr
6+
use stdlib_linalg_lapack_aux, only: handle_geqrf_info, handle_orgqr_info, handle_geqp3_info
77
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
88
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
99
implicit none
@@ -220,4 +220,201 @@ submodule (stdlib_linalg) stdlib_linalg_qr
220220

221221
#:endfor
222222

223+
!---------------------------------------------------------
224+
!----- QR decomposition with column pivoting -----
225+
!---------------------------------------------------------
226+
227+
#:for rk, rt, ri in RC_KINDS_TYPES
228+
! Get workspace size for QR operations
229+
pure module subroutine get_pivoting_qr_${ri}$_workspace(a,lwork,pivoting,err)
230+
!> Input matrix a[m,n]
231+
${rt}$, intent(in), target :: a(:,:)
232+
!> Minimum workspace size for both operations
233+
integer(ilp), intent(out) :: lwork
234+
!> Pivoting flag.
235+
logical(lk), intent(in) :: pivoting
236+
!> State return flag. Returns an error if the query failed
237+
type(linalg_state_type), optional, intent(out) :: err
238+
239+
integer(ilp) :: m,n,k,info,lwork_qr,lwork_ord
240+
${rt}$ :: work_dummy(1),tau_dummy(1),a_dummy(1,1)
241+
integer(ilp) :: jpvt_dummy(1)
242+
real(${rk}$) :: rwork_dummy(1)
243+
type(linalg_state_type) :: err0
244+
245+
if (pivoting) then
246+
lwork = -1_ilp
247+
248+
!> Problem sizes
249+
m = size(a,1,kind=ilp)
250+
n = size(a,2,kind=ilp)
251+
k = min(m,n)
252+
253+
! QR space
254+
lwork_qr = -1_ilp
255+
#:if rt.startswith('complex')
256+
call geqp3(m, n, a_dummy, m, jpvt_dummy, tau_dummy, work_dummy, lwork_qr, rwork_dummy, info)
257+
#:else
258+
call geqp3(m, n, a_dummy, m, jpvt_dummy, tau_dummy, work_dummy, lwork_qr, info)
259+
#:endif
260+
call handle_geqp3_info(this, info, m, n, lwork_qr, err0)
261+
if (err0%error()) then
262+
call linalg_error_handling(err0,err)
263+
return
264+
endif
265+
lwork_qr = ceiling(real(work_dummy(1),kind=${rk}$),kind=ilp)
266+
267+
! Ordering space (for full factorization)
268+
lwork_ord = -1_ilp
269+
call #{if rt.startswith('complex')}# ungqr #{else}# orgqr #{endif}# &
270+
(m,m,k,a_dummy,m,tau_dummy,work_dummy,lwork_ord,info)
271+
call handle_orgqr_info(this,info,m,n,k,lwork_ord,err0)
272+
if (err0%error()) then
273+
call linalg_error_handling(err0,err)
274+
return
275+
endif
276+
lwork_ord = ceiling(real(work_dummy(1),kind=${rk}$),kind=ilp)
277+
278+
! Pick the largest size, so two operations can be performed with the same allocation
279+
lwork = max(lwork_qr, lwork_ord)
280+
else
281+
call qr_space(a, lwork, err)
282+
endif
283+
284+
end subroutine get_pivoting_qr_${ri}$_workspace
285+
286+
pure module subroutine stdlib_linalg_${ri}$_pivoting_qr(a, q, r, pivots, overwrite_a, storage, err)
287+
!> Input matrix a[m, n]
288+
${rt}$, intent(inout), target :: a(:, :)
289+
!> Orthogonal matrix Q ([m, m] or [m, k] if reduced)
290+
${rt}$, intent(out), contiguous, target :: q(:, :)
291+
!> Upper triangular matrix R ([m, n] or [k, n] if reduced)
292+
${rt}$, intent(out), contiguous, target :: r(:, :)
293+
!> Pivots.
294+
integer(ilp), intent(out) :: pivots(:)
295+
!> [optional] Can A data be overwritten and destroyed?
296+
logical(lk), optional, intent(in) :: overwrite_a
297+
!> [optional] Provide pre-allocated workspace, size to be checked with pivoting_qr_space.
298+
${rt}$, intent(out), optional, target :: storage(:)
299+
!> [optional] state return flag. On error if not requested, the code will stop.
300+
type(linalg_state_type), optional, intent(out) :: err
301+
302+
!> Local variables.
303+
type(linalg_state_type) :: err0
304+
integer(ilp) :: i, j, m, n, k, q1, q2, r1, r2, lda, lwork, info
305+
logical(lk) :: overwrite_a_, use_q_matrix, reduced
306+
${rt}$ :: r11
307+
${rt}$, parameter :: zero = 0.0_${rk}$
308+
${rt}$, pointer :: amat(:, :), tau(:), work(:)
309+
#:if rt.startswith('complex')
310+
real(${rk}$) :: rwork(2*size(a, 2, kind=ilp))
311+
#:endif
312+
313+
!> Problem sizes.
314+
m = size(a, 1, kind=ilp)
315+
n = size(a, 2, kind=ilp)
316+
k = min(m, n)
317+
q1 = size(q, 1, kind=ilp)
318+
q2 = size(q, 2, kind=ilp)
319+
r1 = size(r, 1, kind=ilp)
320+
r2 = size(r, 2, kind=ilp)
321+
pivots = 0_ilp
322+
323+
!> Full or thin QR factorization ?
324+
call check_problem_size(m, n, q1, q2, r1, r2, err0, reduced)
325+
if (err0%error()) then
326+
call linalg_error_handling(err0, err)
327+
return
328+
endif
329+
330+
!> Can Q be used as temporary storage for A,
331+
! to be destroyed by *GEQP3.
332+
use_q_matrix = q1 >= m .and. q2 >= n
333+
334+
!> Can A be overwritten ? (By default, no).
335+
overwrite_a_ = .false._lk
336+
if (present(overwrite_a) .and. .not. use_q_matrix) overwrite_a_ = overwrite_a
337+
338+
!> Initialize a temporary matrix or reuse available storage if possible.
339+
if (use_q_matrix) then
340+
amat => q
341+
q(:m, :n) = a
342+
else if (overwrite_a_) then
343+
amat => a
344+
else
345+
allocate(amat(m, n), source=a)
346+
endif
347+
lda = size(amat, 1, kind=ilp)
348+
349+
!> Store the elementary reflectors.
350+
if (.not. use_q_matrix) then
351+
! Q is not being used as the storage matrix.
352+
tau(1:k) => q(1:k, 1)
353+
else
354+
! R has unused contiguous storage in the 1st column, except for the
355+
! r11 element. Use the full column and store it in a dummy variable.
356+
tau(1:k) => r(1:k, 1)
357+
endif
358+
359+
! Retrieve workspace size.
360+
call get_pivoting_qr_${ri}$_workspace(a, lwork, .true., err0)
361+
362+
if (err0%ok()) then
363+
364+
if (present(storage)) then
365+
work => storage
366+
else
367+
allocate(work(lwork))
368+
endif
369+
if (.not. size(work, kind=ilp) >= lwork) then
370+
err0 = linalg_state_type(this, LINALG_ERROR, "insufficient workspace: should be at least ", lwork)
371+
call linalg_error_handling(err0, err)
372+
return
373+
endif
374+
375+
! Compute factorization.
376+
#:if rt.startswith('complex')
377+
call geqp3(m, n, amat, m, pivots, tau, work, lwork, rwork, info)
378+
#:else
379+
call geqp3(m, n, amat, m, pivots, tau, work, lwork, info)
380+
#:endif
381+
call handle_geqp3_info(this, info, m, n, lwork, err0)
382+
383+
if (err0%ok()) then
384+
! Get R matrix out before overwritten.
385+
! Do not copy the first column at this stage: it may be used by `tau`
386+
r11 = amat(1, 1)
387+
do concurrent(i=1:min(r1, m), j=2:n)
388+
r(i, j) = merge(amat(i, j), zero, i <= j)
389+
enddo
390+
391+
! Convert K elementary reflectors tau(1:k) -> orthogonal matrix Q
392+
call #{if rt.startswith('complex')}# ungqr #{else}# orgqr #{endif}# &
393+
(q1,q2,k,amat,lda,tau,work,lwork,info)
394+
call handle_orgqr_info(this,info,m,n,k,lwork,err0)
395+
396+
! Copy result back to Q
397+
if (.not.use_q_matrix) q = amat(:q1,:q2)
398+
399+
! Copy first column of R
400+
r(1,1) = r11
401+
r(2:r1,1) = zero
402+
403+
! Ensure last m-n rows of R are zeros,
404+
! if full matrices were provided
405+
if (.not.reduced) r(k+1:m,1:n) = zero
406+
endif
407+
408+
if (.not. present(storage)) deallocate(work)
409+
410+
endif
411+
412+
if (.not.(use_q_matrix.or.overwrite_a_)) deallocate(amat)
413+
414+
! Process output and return
415+
call linalg_error_handling(err0,err)
416+
417+
end subroutine stdlib_linalg_${ri}$_pivoting_qr
418+
#:endfor
419+
223420
end submodule stdlib_linalg_qr

0 commit comments

Comments
 (0)