|
2 | 2 | #:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES |
3 | 3 | submodule (stdlib_linalg) stdlib_linalg_qr |
4 | 4 | use stdlib_linalg_constants |
5 | | - use stdlib_linalg_lapack, only: geqrf, orgqr, ungqr |
6 | | - use stdlib_linalg_lapack_aux, only: handle_geqrf_info, handle_orgqr_info |
| 5 | + use stdlib_linalg_lapack, only: geqrf, geqp3, orgqr, ungqr |
| 6 | + use stdlib_linalg_lapack_aux, only: handle_geqrf_info, handle_orgqr_info, handle_geqp3_info |
7 | 7 | use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, & |
8 | 8 | LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR |
9 | 9 | implicit none |
@@ -220,4 +220,201 @@ submodule (stdlib_linalg) stdlib_linalg_qr |
220 | 220 |
|
221 | 221 | #:endfor |
222 | 222 |
|
| 223 | + !--------------------------------------------------------- |
| 224 | + !----- QR decomposition with column pivoting ----- |
| 225 | + !--------------------------------------------------------- |
| 226 | + |
| 227 | + #:for rk, rt, ri in RC_KINDS_TYPES |
| 228 | + ! Get workspace size for QR operations |
| 229 | + pure module subroutine get_pivoting_qr_${ri}$_workspace(a,lwork,pivoting,err) |
| 230 | + !> Input matrix a[m,n] |
| 231 | + ${rt}$, intent(in), target :: a(:,:) |
| 232 | + !> Minimum workspace size for both operations |
| 233 | + integer(ilp), intent(out) :: lwork |
| 234 | + !> Pivoting flag. |
| 235 | + logical(lk), intent(in) :: pivoting |
| 236 | + !> State return flag. Returns an error if the query failed |
| 237 | + type(linalg_state_type), optional, intent(out) :: err |
| 238 | + |
| 239 | + integer(ilp) :: m,n,k,info,lwork_qr,lwork_ord |
| 240 | + ${rt}$ :: work_dummy(1),tau_dummy(1),a_dummy(1,1) |
| 241 | + integer(ilp) :: jpvt_dummy(1) |
| 242 | + real(${rk}$) :: rwork_dummy(1) |
| 243 | + type(linalg_state_type) :: err0 |
| 244 | + |
| 245 | + if (pivoting) then |
| 246 | + lwork = -1_ilp |
| 247 | + |
| 248 | + !> Problem sizes |
| 249 | + m = size(a,1,kind=ilp) |
| 250 | + n = size(a,2,kind=ilp) |
| 251 | + k = min(m,n) |
| 252 | + |
| 253 | + ! QR space |
| 254 | + lwork_qr = -1_ilp |
| 255 | + #:if rt.startswith('complex') |
| 256 | + call geqp3(m, n, a_dummy, m, jpvt_dummy, tau_dummy, work_dummy, lwork_qr, rwork_dummy, info) |
| 257 | + #:else |
| 258 | + call geqp3(m, n, a_dummy, m, jpvt_dummy, tau_dummy, work_dummy, lwork_qr, info) |
| 259 | + #:endif |
| 260 | + call handle_geqp3_info(this, info, m, n, lwork_qr, err0) |
| 261 | + if (err0%error()) then |
| 262 | + call linalg_error_handling(err0,err) |
| 263 | + return |
| 264 | + endif |
| 265 | + lwork_qr = ceiling(real(work_dummy(1),kind=${rk}$),kind=ilp) |
| 266 | + |
| 267 | + ! Ordering space (for full factorization) |
| 268 | + lwork_ord = -1_ilp |
| 269 | + call #{if rt.startswith('complex')}# ungqr #{else}# orgqr #{endif}# & |
| 270 | + (m,m,k,a_dummy,m,tau_dummy,work_dummy,lwork_ord,info) |
| 271 | + call handle_orgqr_info(this,info,m,n,k,lwork_ord,err0) |
| 272 | + if (err0%error()) then |
| 273 | + call linalg_error_handling(err0,err) |
| 274 | + return |
| 275 | + endif |
| 276 | + lwork_ord = ceiling(real(work_dummy(1),kind=${rk}$),kind=ilp) |
| 277 | + |
| 278 | + ! Pick the largest size, so two operations can be performed with the same allocation |
| 279 | + lwork = max(lwork_qr, lwork_ord) |
| 280 | + else |
| 281 | + call qr_space(a, lwork, err) |
| 282 | + endif |
| 283 | + |
| 284 | + end subroutine get_pivoting_qr_${ri}$_workspace |
| 285 | + |
| 286 | + pure module subroutine stdlib_linalg_${ri}$_pivoting_qr(a, q, r, pivots, overwrite_a, storage, err) |
| 287 | + !> Input matrix a[m, n] |
| 288 | + ${rt}$, intent(inout), target :: a(:, :) |
| 289 | + !> Orthogonal matrix Q ([m, m] or [m, k] if reduced) |
| 290 | + ${rt}$, intent(out), contiguous, target :: q(:, :) |
| 291 | + !> Upper triangular matrix R ([m, n] or [k, n] if reduced) |
| 292 | + ${rt}$, intent(out), contiguous, target :: r(:, :) |
| 293 | + !> Pivots. |
| 294 | + integer(ilp), intent(out) :: pivots(:) |
| 295 | + !> [optional] Can A data be overwritten and destroyed? |
| 296 | + logical(lk), optional, intent(in) :: overwrite_a |
| 297 | + !> [optional] Provide pre-allocated workspace, size to be checked with pivoting_qr_space. |
| 298 | + ${rt}$, intent(out), optional, target :: storage(:) |
| 299 | + !> [optional] state return flag. On error if not requested, the code will stop. |
| 300 | + type(linalg_state_type), optional, intent(out) :: err |
| 301 | + |
| 302 | + !> Local variables. |
| 303 | + type(linalg_state_type) :: err0 |
| 304 | + integer(ilp) :: i, j, m, n, k, q1, q2, r1, r2, lda, lwork, info |
| 305 | + logical(lk) :: overwrite_a_, use_q_matrix, reduced |
| 306 | + ${rt}$ :: r11 |
| 307 | + ${rt}$, parameter :: zero = 0.0_${rk}$ |
| 308 | + ${rt}$, pointer :: amat(:, :), tau(:), work(:) |
| 309 | + #:if rt.startswith('complex') |
| 310 | + real(${rk}$) :: rwork(2*size(a, 2, kind=ilp)) |
| 311 | + #:endif |
| 312 | + |
| 313 | + !> Problem sizes. |
| 314 | + m = size(a, 1, kind=ilp) |
| 315 | + n = size(a, 2, kind=ilp) |
| 316 | + k = min(m, n) |
| 317 | + q1 = size(q, 1, kind=ilp) |
| 318 | + q2 = size(q, 2, kind=ilp) |
| 319 | + r1 = size(r, 1, kind=ilp) |
| 320 | + r2 = size(r, 2, kind=ilp) |
| 321 | + pivots = 0_ilp |
| 322 | + |
| 323 | + !> Full or thin QR factorization ? |
| 324 | + call check_problem_size(m, n, q1, q2, r1, r2, err0, reduced) |
| 325 | + if (err0%error()) then |
| 326 | + call linalg_error_handling(err0, err) |
| 327 | + return |
| 328 | + endif |
| 329 | + |
| 330 | + !> Can Q be used as temporary storage for A, |
| 331 | + ! to be destroyed by *GEQP3. |
| 332 | + use_q_matrix = q1 >= m .and. q2 >= n |
| 333 | + |
| 334 | + !> Can A be overwritten ? (By default, no). |
| 335 | + overwrite_a_ = .false._lk |
| 336 | + if (present(overwrite_a) .and. .not. use_q_matrix) overwrite_a_ = overwrite_a |
| 337 | + |
| 338 | + !> Initialize a temporary matrix or reuse available storage if possible. |
| 339 | + if (use_q_matrix) then |
| 340 | + amat => q |
| 341 | + q(:m, :n) = a |
| 342 | + else if (overwrite_a_) then |
| 343 | + amat => a |
| 344 | + else |
| 345 | + allocate(amat(m, n), source=a) |
| 346 | + endif |
| 347 | + lda = size(amat, 1, kind=ilp) |
| 348 | + |
| 349 | + !> Store the elementary reflectors. |
| 350 | + if (.not. use_q_matrix) then |
| 351 | + ! Q is not being used as the storage matrix. |
| 352 | + tau(1:k) => q(1:k, 1) |
| 353 | + else |
| 354 | + ! R has unused contiguous storage in the 1st column, except for the |
| 355 | + ! r11 element. Use the full column and store it in a dummy variable. |
| 356 | + tau(1:k) => r(1:k, 1) |
| 357 | + endif |
| 358 | + |
| 359 | + ! Retrieve workspace size. |
| 360 | + call get_pivoting_qr_${ri}$_workspace(a, lwork, .true., err0) |
| 361 | + |
| 362 | + if (err0%ok()) then |
| 363 | + |
| 364 | + if (present(storage)) then |
| 365 | + work => storage |
| 366 | + else |
| 367 | + allocate(work(lwork)) |
| 368 | + endif |
| 369 | + if (.not. size(work, kind=ilp) >= lwork) then |
| 370 | + err0 = linalg_state_type(this, LINALG_ERROR, "insufficient workspace: should be at least ", lwork) |
| 371 | + call linalg_error_handling(err0, err) |
| 372 | + return |
| 373 | + endif |
| 374 | + |
| 375 | + ! Compute factorization. |
| 376 | + #:if rt.startswith('complex') |
| 377 | + call geqp3(m, n, amat, m, pivots, tau, work, lwork, rwork, info) |
| 378 | + #:else |
| 379 | + call geqp3(m, n, amat, m, pivots, tau, work, lwork, info) |
| 380 | + #:endif |
| 381 | + call handle_geqp3_info(this, info, m, n, lwork, err0) |
| 382 | + |
| 383 | + if (err0%ok()) then |
| 384 | + ! Get R matrix out before overwritten. |
| 385 | + ! Do not copy the first column at this stage: it may be used by `tau` |
| 386 | + r11 = amat(1, 1) |
| 387 | + do concurrent(i=1:min(r1, m), j=2:n) |
| 388 | + r(i, j) = merge(amat(i, j), zero, i <= j) |
| 389 | + enddo |
| 390 | + |
| 391 | + ! Convert K elementary reflectors tau(1:k) -> orthogonal matrix Q |
| 392 | + call #{if rt.startswith('complex')}# ungqr #{else}# orgqr #{endif}# & |
| 393 | + (q1,q2,k,amat,lda,tau,work,lwork,info) |
| 394 | + call handle_orgqr_info(this,info,m,n,k,lwork,err0) |
| 395 | + |
| 396 | + ! Copy result back to Q |
| 397 | + if (.not.use_q_matrix) q = amat(:q1,:q2) |
| 398 | + |
| 399 | + ! Copy first column of R |
| 400 | + r(1,1) = r11 |
| 401 | + r(2:r1,1) = zero |
| 402 | + |
| 403 | + ! Ensure last m-n rows of R are zeros, |
| 404 | + ! if full matrices were provided |
| 405 | + if (.not.reduced) r(k+1:m,1:n) = zero |
| 406 | + endif |
| 407 | + |
| 408 | + if (.not. present(storage)) deallocate(work) |
| 409 | + |
| 410 | + endif |
| 411 | + |
| 412 | + if (.not.(use_q_matrix.or.overwrite_a_)) deallocate(amat) |
| 413 | + |
| 414 | + ! Process output and return |
| 415 | + call linalg_error_handling(err0,err) |
| 416 | + |
| 417 | + end subroutine stdlib_linalg_${ri}$_pivoting_qr |
| 418 | + #:endfor |
| 419 | + |
223 | 420 | end submodule stdlib_linalg_qr |
0 commit comments