1010
1111namespace slate {
1212
13- // specialization namespace differentiates, e.g.,
14- // internal::geqrf_qdwh_full from internal::specialization::geqrf
15- namespace internal {
16- namespace specialization {
13+ namespace impl {
1714
1815// ------------------------------------------------------------------------------
1916// / An auxiliary routine to find each rank's first (top-most) row
@@ -31,8 +28,9 @@ namespace specialization {
3128// / @ingroup geqrf_qdwh_full_specialization
3229// /
3330template <typename scalar_t >
34- void geqrf_compute_first_indices (Matrix<scalar_t >& A_panel, int64_t k,
35- std::vector< int64_t >& first_indices)
31+ void geqrf_compute_first_indices (
32+ Matrix<scalar_t >& A_panel, int64_t k,
33+ std::vector< int64_t >& first_indices)
3634{
3735 // Find ranks in this column.
3836 std::set<int > ranks_set;
@@ -60,13 +58,13 @@ void geqrf_compute_first_indices(Matrix<scalar_t>& A_panel, int64_t k,
6058// /
6159// / ColMajor layout is assumed
6260// /
63- // / @ingroup geqrf_specialization
61+ // / @ingroup geqrf_impl
6462// /
6563template <Target target, typename scalar_t >
66- void geqrf_qdwh_full (slate::internal::TargetType<target>,
64+ void geqrf_qdwh_full (
6765 Matrix<scalar_t >& A,
6866 TriangularFactors<scalar_t >& T,
69- int64_t ib, int max_panel_threads, int64_t lookahead )
67+ Options const & opts )
7068{
7169 using BcastList = typename Matrix<scalar_t >::BcastList;
7270 using device_info_t = lapack::device_info_int;
@@ -78,7 +76,15 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
7876 const int priority_zero = 0 ;
7977 const int priority_one = 1 ;
8078 const int life_factor_one = 1 ;
81- const bool set_hold = lookahead > 0 ; // Do tileGetAndHold in the bcast
79+
80+ // Options
81+ int64_t lookahead = get_option<int64_t >( opts, Option::Lookahead, 1 );
82+ int64_t ib = get_option<int64_t >( opts, Option::InnerBlocking, 16 );
83+ int64_t max_panel_threads = std::max (omp_get_max_threads ()/2 , 1 );
84+ max_panel_threads = get_option<int64_t >( opts, Option::MaxPanelThreads,
85+ max_panel_threads );
86+
87+ bool set_hold = lookahead > 0 ; // Do tileGetAndHold in the bcast
8288
8389 int64_t A_mt = A.mt ();
8490 int64_t A_nt = A.nt ();
@@ -364,34 +370,17 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
364370 }
365371}
366372
367- } // namespace specialization
368- } // namespace internal
369-
370- // ------------------------------------------------------------------------------
371- // / Version with target as template parameter.
372- // / @ingroup geqrf_qdwh_full_specialization
373- // /
374- template <Target target, typename scalar_t >
375- void geqrf_qdwh_full (Matrix<scalar_t >& A,
376- TriangularFactors<scalar_t >& T,
377- Options const & opts)
378- {
379- int64_t lookahead = get_option<int64_t >( opts, Option::Lookahead, 1 );
380-
381- int64_t ib = get_option<int64_t >( opts, Option::InnerBlocking, 16 );
382-
383- int64_t max_panel_threads = std::max (omp_get_max_threads ()/2 , 1 );
384- max_panel_threads = get_option<int64_t >( opts, Option::MaxPanelThreads, max_panel_threads );
385-
386- internal::specialization::geqrf_qdwh_full (internal::TargetType<target>(),
387- A, T,
388- ib, max_panel_threads, lookahead);
389- }
373+ } // namespace impl
390374
391375// ------------------------------------------------------------------------------
392- // / Distributed parallel QR factorization.
376+ // / Distributed parallel customized QR factorization.
377+ // / Required for the QR-based iterations in the polar decomposition QDWH.
393378// /
394- // / Computes a QR factorization of an m-by-n matrix $A$.
379+ // / Computes a QR factorization of an m-by-n matrix $A$, $m \ge 2n$,
380+ // / and takes advantage of the trailing identity matrix structure.
381+ // / A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
382+ // / [ A1 ] identity matrix (n-by-n)
383+ // / Avoids doing computations on the zero tiles below the diagonal of $A1$.
395384// / The factorization has the form
396385// / \[
397386// / A = QR,
@@ -401,15 +390,16 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
401390// /
402391// / Complexity (in real):
403392// / - for $m \ge n$, $\approx 2 m n^{2} - \frac{2}{3} n^{3}$ flops;
404- // / - for $m \le n$, $\approx 2 m^{2} n - \frac{2}{3} m^{3}$ flops;
405- // / - for $m = n$, $\approx \frac{4}{3} n^{3}$ flops.
406393// / .
407394// ------------------------------------------------------------------------------
408395// / @tparam scalar_t
409396// / One of float, double, std::complex<float>, std::complex<double>.
410397// ------------------------------------------------------------------------------
411398// / @param[in,out] A
412- // / On entry, the m-by-n matrix $A$.
399+ // / On entry, the m-by-n matrix $A$, $m \ge 2n$,
400+ // / A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
401+ // / [ A1 ] identity matrix (n-by-n)
402+ // /
413403// / On exit, the elements on and above the diagonal of the array contain
414404// / the min(m,n)-by-n upper trapezoidal matrix $R$ (upper triangular
415405// / if m >= n); the elements below the diagonal represent the unitary
@@ -437,7 +427,8 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
437427// / @ingroup geqrf_computational
438428// /
439429template <typename scalar_t >
440- void geqrf_qdwh_full (Matrix<scalar_t >& A,
430+ void geqrf_qdwh_full (
431+ Matrix<scalar_t >& A,
441432 TriangularFactors<scalar_t >& T,
442433 Options const & opts)
443434{
@@ -446,16 +437,19 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
446437 switch (target) {
447438 case Target::Host:
448439 case Target::HostTask:
449- geqrf_qdwh_full<Target::HostTask>(A, T, opts);
440+ impl:: geqrf_qdwh_full<Target::HostTask>( A, T, opts );
450441 break ;
442+
451443 case Target::HostNest:
452- geqrf_qdwh_full<Target::HostNest>(A, T, opts);
444+ impl:: geqrf_qdwh_full<Target::HostNest>( A, T, opts );
453445 break ;
446+
454447 case Target::HostBatch:
455- geqrf_qdwh_full<Target::HostBatch>(A, T, opts);
448+ impl:: geqrf_qdwh_full<Target::HostBatch>( A, T, opts );
456449 break ;
450+
457451 case Target::Devices:
458- geqrf_qdwh_full<Target::Devices>(A, T, opts);
452+ impl:: geqrf_qdwh_full<Target::Devices>( A, T, opts );
459453 break ;
460454 }
461455 // todo: return value for errors?
0 commit comments