Skip to content

Commit 782dcc9

Browse files
committed
Make the extra memory usage specific to DSP hardware. Add some comments.
1 parent 41abbe7 commit 782dcc9

File tree

1 file changed

+62
-29
lines changed

1 file changed

+62
-29
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
182182
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);
183183

184184
#ifdef __DSP
185-
gemm_op_mt<T, Device>()
185+
gemm_op_mt<T, Device>() // Used here to avoid writing another whole template and to keep the code change minimal.
186186
#else
187187
gemm_op<T, Device>()
188188
#endif
@@ -444,6 +444,9 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
444444
#ifdef __MPI
445445
if (this->diag_comm.nproc > 1)
446446
{
447+
#ifdef __DSP
448+
// Only DSP hardware needs an extra buffer to hold the reduced data.
449+
447450
auto* swap = new T[notconv * this->nbase_x];
448451
auto* target = new T[notconv * this->nbase_x];
449452

@@ -458,13 +461,6 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
458461
{
459462
if (base_device::get_current_precision(swap) == "single")
460463
{
461-
// MPI_Reduce(swap,
462-
// hcc + nbase * this->nbase_x,
463-
// notconv * this->nbase_x,
464-
// MPI_COMPLEX,
465-
// MPI_SUM,
466-
// 0,
467-
// this->diag_comm.comm);
468464
MPI_Reduce(swap,
469465
target,
470466
notconv * this->nbase_x,
@@ -475,13 +471,6 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
475471
}
476472
else
477473
{
478-
// MPI_Reduce(swap,
479-
// hcc + nbase * this->nbase_x,
480-
// notconv * this->nbase_x,
481-
// MPI_DOUBLE_COMPLEX,
482-
// MPI_SUM,
483-
// 0,
484-
// this->diag_comm.comm);
485474
MPI_Reduce(swap,
486475
target,
487476
notconv * this->nbase_x,
@@ -496,13 +485,6 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
496485

497486
if (base_device::get_current_precision(swap) == "single")
498487
{
499-
// MPI_Reduce(swap,
500-
// scc + nbase * this->nbase_x,
501-
// notconv * this->nbase_x,
502-
// MPI_COMPLEX,
503-
// MPI_SUM,
504-
// 0,
505-
// this->diag_comm.comm);
506488
MPI_Reduce(swap,
507489
target,
508490
notconv * this->nbase_x,
@@ -513,13 +495,6 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
513495
}
514496
else
515497
{
516-
// MPI_Reduce(swap,
517-
// scc + nbase * this->nbase_x,
518-
// notconv * this->nbase_x,
519-
// MPI_DOUBLE_COMPLEX,
520-
// MPI_SUM,
521-
// 0,
522-
// this->diag_comm.comm);
523498
MPI_Reduce(swap,
524499
target,
525500
notconv * this->nbase_x,
@@ -532,6 +507,64 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
532507
syncmem_complex_op()(this->ctx, this->ctx, scc + nbase * this->nbase_x, target, notconv * this->nbase_x);
533508
delete[] swap;
534509
delete[] target;
510+
#else
511+
auto* swap = new T[notconv * this->nbase_x];
512+
513+
syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);
514+
515+
if (std::is_same<T, double>::value)
516+
{
517+
Parallel_Reduce::reduce_pool(hcc + nbase * this->nbase_x, notconv * this->nbase_x);
518+
Parallel_Reduce::reduce_pool(scc + nbase * this->nbase_x, notconv * this->nbase_x);
519+
}
520+
else
521+
{
522+
if (base_device::get_current_precision(swap) == "single")
523+
{
524+
MPI_Reduce(swap,
525+
hcc + nbase * this->nbase_x,
526+
notconv * this->nbase_x,
527+
MPI_COMPLEX,
528+
MPI_SUM,
529+
0,
530+
this->diag_comm.comm);
531+
}
532+
else
533+
{
534+
MPI_Reduce(swap,
535+
hcc + nbase * this->nbase_x,
536+
notconv * this->nbase_x,
537+
MPI_DOUBLE_COMPLEX,
538+
MPI_SUM,
539+
0,
540+
this->diag_comm.comm);
541+
}
542+
543+
syncmem_complex_op()(this->ctx, this->ctx, swap, scc + nbase * this->nbase_x, notconv * this->nbase_x);
544+
545+
if (base_device::get_current_precision(swap) == "single")
546+
{
547+
MPI_Reduce(swap,
548+
scc + nbase * this->nbase_x,
549+
notconv * this->nbase_x,
550+
MPI_COMPLEX,
551+
MPI_SUM,
552+
0,
553+
this->diag_comm.comm);
554+
}
555+
else
556+
{
557+
MPI_Reduce(swap,
558+
scc + nbase * this->nbase_x,
559+
notconv * this->nbase_x,
560+
MPI_DOUBLE_COMPLEX,
561+
MPI_SUM,
562+
0,
563+
this->diag_comm.comm);
564+
}
565+
}
566+
delete[] swap;
567+
#endif
535568
}
536569
#endif
537570

0 commit comments

Comments
 (0)