Skip to content

Commit 6ea76e6

Browse files
author
Dalal Sukkari
committed
Fix in makefil, used impl namesapce and minor to the docs in geqrf_qdwh_full
1 parent 22e050a commit 6ea76e6

File tree

2 files changed

+38
-44
lines changed

2 files changed

+38
-44
lines changed

GNUmakefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ tester_src += \
639639
test/test_pbsv.cc \
640640
test/test_posv.cc \
641641
test/test_potri.cc \
642+
test/test_qdwh.cc \
642643
test/test_scale.cc \
643644
test/test_scale_row_col.cc \
644645
test/test_set.cc \
@@ -676,7 +677,6 @@ ifneq ($(have_fortran),)
676677
test/pdlantr.f \
677678
test/pclantr.f \
678679
test/pzlantr.f \
679-
test/test_qdwh.cc \
680680
# End. Add alphabetically, by base name after precision.
681681
endif
682682
endif

src/geqrf_qdwh_full.cc

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010

1111
namespace slate {
1212

13-
// specialization namespace differentiates, e.g.,
14-
// internal::geqrf_qdwh_full from internal::specialization::geqrf
15-
namespace internal {
16-
namespace specialization {
13+
namespace impl {
1714

1815
//------------------------------------------------------------------------------
1916
/// An auxiliary routine to find each rank's first (top-most) row
@@ -31,8 +28,9 @@ namespace specialization {
3128
/// @ingroup geqrf_qdwh_full_specialization
3229
///
3330
template <typename scalar_t>
34-
void geqrf_compute_first_indices(Matrix<scalar_t>& A_panel, int64_t k,
35-
std::vector< int64_t >& first_indices)
31+
void geqrf_compute_first_indices(
32+
Matrix<scalar_t>& A_panel, int64_t k,
33+
std::vector< int64_t >& first_indices)
3634
{
3735
// Find ranks in this column.
3836
std::set<int> ranks_set;
@@ -60,13 +58,13 @@ void geqrf_compute_first_indices(Matrix<scalar_t>& A_panel, int64_t k,
6058
///
6159
/// ColMajor layout is assumed
6260
///
63-
/// @ingroup geqrf_specialization
61+
/// @ingroup geqrf_impl
6462
///
6563
template <Target target, typename scalar_t>
66-
void geqrf_qdwh_full(slate::internal::TargetType<target>,
64+
void geqrf_qdwh_full(
6765
Matrix<scalar_t>& A,
6866
TriangularFactors<scalar_t>& T,
69-
int64_t ib, int max_panel_threads, int64_t lookahead)
67+
Options const& opts )
7068
{
7169
using BcastList = typename Matrix<scalar_t>::BcastList;
7270
using device_info_t = lapack::device_info_int;
@@ -78,7 +76,15 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
7876
const int priority_zero = 0;
7977
const int priority_one = 1;
8078
const int life_factor_one = 1;
81-
const bool set_hold = lookahead > 0; // Do tileGetAndHold in the bcast
79+
80+
// Options
81+
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );
82+
int64_t ib = get_option<int64_t>( opts, Option::InnerBlocking, 16 );
83+
int64_t max_panel_threads = std::max(omp_get_max_threads()/2, 1);
84+
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads,
85+
max_panel_threads );
86+
87+
bool set_hold = lookahead > 0; // Do tileGetAndHold in the bcast
8288

8389
int64_t A_mt = A.mt();
8490
int64_t A_nt = A.nt();
@@ -364,34 +370,17 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
364370
}
365371
}
366372

367-
} // namespace specialization
368-
} // namespace internal
369-
370-
//------------------------------------------------------------------------------
371-
/// Version with target as template parameter.
372-
/// @ingroup geqrf_qdwh_full_specialization
373-
///
374-
template <Target target, typename scalar_t>
375-
void geqrf_qdwh_full(Matrix<scalar_t>& A,
376-
TriangularFactors<scalar_t>& T,
377-
Options const& opts)
378-
{
379-
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );
380-
381-
int64_t ib = get_option<int64_t>( opts, Option::InnerBlocking, 16 );
382-
383-
int64_t max_panel_threads = std::max(omp_get_max_threads()/2, 1);
384-
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads, max_panel_threads );
385-
386-
internal::specialization::geqrf_qdwh_full(internal::TargetType<target>(),
387-
A, T,
388-
ib, max_panel_threads, lookahead);
389-
}
373+
} // namespace impl
390374

391375
//------------------------------------------------------------------------------
392-
/// Distributed parallel QR factorization.
376+
/// Distributed parallel customized QR factorization.
377+
/// Required for the QR-based iterations in the polar decomposition QDWH.
393378
///
394-
/// Computes a QR factorization of an m-by-n matrix $A$.
379+
/// Computes a QR factorization of m-by-n matrix $A$, m \ge 2n,
380+
/// and takes advantage of the trailing identity matrix structure.
381+
/// A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
382+
/// [ A1 ] identity matrix (n-by-n)
383+
/// Avoids doing computaions on the zero tiles below the diagonal of $A1$.
395384
/// The factorization has the form
396385
/// \[
397386
/// A = QR,
@@ -401,15 +390,16 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
401390
///
402391
/// Complexity (in real):
403392
/// - for $m \ge n$, $\approx 2 m n^{2} - \frac{2}{3} n^{3}$ flops;
404-
/// - for $m \le n$, $\approx 2 m^{2} n - \frac{2}{3} m^{3}$ flops;
405-
/// - for $m = n$, $\approx \frac{4}{3} n^{3}$ flops.
406393
/// .
407394
//------------------------------------------------------------------------------
408395
/// @tparam scalar_t
409396
/// One of float, double, std::complex<float>, std::complex<double>.
410397
//------------------------------------------------------------------------------
411398
/// @param[in,out] A
412-
/// On entry, the m-by-n matrix $A$.
399+
/// On entry, the m-by-n matrix $A$, m \ge 2n,
400+
/// A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
401+
/// [ A1 ] identity matrix (n-by-n)
402+
///
413403
/// On exit, the elements on and above the diagonal of the array contain
414404
/// the min(m,n)-by-n upper trapezoidal matrix $R$ (upper triangular
415405
/// if m >= n); the elements below the diagonal represent the unitary
@@ -437,7 +427,8 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
437427
/// @ingroup geqrf_computational
438428
///
439429
template <typename scalar_t>
440-
void geqrf_qdwh_full(Matrix<scalar_t>& A,
430+
void geqrf_qdwh_full(
431+
Matrix<scalar_t>& A,
441432
TriangularFactors<scalar_t>& T,
442433
Options const& opts)
443434
{
@@ -446,16 +437,19 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
446437
switch (target) {
447438
case Target::Host:
448439
case Target::HostTask:
449-
geqrf_qdwh_full<Target::HostTask>(A, T, opts);
440+
impl::geqrf_qdwh_full<Target::HostTask>( A, T, opts );
450441
break;
442+
451443
case Target::HostNest:
452-
geqrf_qdwh_full<Target::HostNest>(A, T, opts);
444+
impl::geqrf_qdwh_full<Target::HostNest>( A, T, opts );
453445
break;
446+
454447
case Target::HostBatch:
455-
geqrf_qdwh_full<Target::HostBatch>(A, T, opts);
448+
impl::geqrf_qdwh_full<Target::HostBatch>( A, T, opts );
456449
break;
450+
457451
case Target::Devices:
458-
geqrf_qdwh_full<Target::Devices>(A, T, opts);
452+
impl::geqrf_qdwh_full<Target::Devices>( A, T, opts );
459453
break;
460454
}
461455
// todo: return value for errors?

0 commit comments

Comments
 (0)