Skip to content

Commit 6c98313

Browse files
committed
Reorganize dsp codes
1 parent 82991ca commit 6c98313

File tree

2 files changed

+70
-61
lines changed

2 files changed

+70
-61
lines changed

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#define DSP_CONNECTOR_H
33
#ifdef __DSP
44

5+
#include "module_base/module_device/device.h"
6+
#include "module_hsolver/diag_comm_info.h"
7+
58
// Base dsp functions
69
void dspInitHandle(int id);
710
void dspDestoryHandle(int id);
@@ -62,5 +65,70 @@ void cgemm_mth_(const char *transa, const char *transb,
6265

6366
//#define zgemm_ zgemm_mt
6467

68+
// The next is dsp utils. It may be moved to other files if this file get too huge
69+
70+
Device* ctx = {};
71+
base_device::DEVICE_CPU* cpu_ctx = {};
72+
base_device::AbacusDevice_t device = {};
73+
74+
template <typename T>
75+
void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase_x, int notconv, const diag_comm_info diag_comm){
76+
77+
using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
78+
79+
auto* swap = new T[notconv * nbase_x];
80+
auto* target = new T[notconv * nbase_x];
81+
syncmem_complex_op()(ctx, ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
82+
if (base_device::get_current_precision(swap) == "single")
83+
{
84+
MPI_Reduce(swap,
85+
target,
86+
notconv * nbase_x,
87+
MPI_COMPLEX,
88+
MPI_SUM,
89+
0,
90+
diag_comm.comm);
91+
}
92+
else
93+
{
94+
MPI_Reduce(swap,
95+
target,
96+
notconv * nbase_x,
97+
MPI_DOUBLE_COMPLEX,
98+
MPI_SUM,
99+
0,
100+
diag_comm.comm);
101+
}
102+
103+
syncmem_complex_op()(ctx, ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
104+
syncmem_complex_op()(ctx, ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);
105+
106+
if (base_device::get_current_precision(swap) == "single")
107+
{
108+
MPI_Reduce(swap,
109+
target,
110+
notconv * nbase_x,
111+
MPI_COMPLEX,
112+
MPI_SUM,
113+
0,
114+
diag_comm.comm);
115+
}
116+
else
117+
{
118+
MPI_Reduce(swap,
119+
target,
120+
notconv * nbase_x,
121+
MPI_DOUBLE_COMPLEX,
122+
MPI_SUM,
123+
0,
124+
diag_comm.comm);
125+
}
126+
127+
syncmem_complex_op()(ctx, ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
128+
delete[] swap;
129+
delete[] target;
130+
}
131+
132+
65133
#endif
66134
#endif

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "module_base/timer.h"
77
#include "module_hsolver/kernels/dngvd_op.h"
88
#include "module_hsolver/kernels/math_kernel_op.h"
9+
#include "module_base/kernels/dsp/dsp_connector.h"
910

1011
#include <vector>
1112

@@ -446,67 +447,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
446447
{
447448
#ifdef __DSP
448449
// Only on dsp hardware need an extra space to reduce data
449-
450-
auto* swap = new T[notconv * this->nbase_x];
451-
auto* target = new T[notconv * this->nbase_x];
452-
453-
syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);
454-
455-
if (std::is_same<T, double>::value)
456-
{
457-
Parallel_Reduce::reduce_pool(hcc + nbase * this->nbase_x, notconv * this->nbase_x);
458-
Parallel_Reduce::reduce_pool(scc + nbase * this->nbase_x, notconv * this->nbase_x);
459-
}
460-
else
461-
{
462-
if (base_device::get_current_precision(swap) == "single")
463-
{
464-
MPI_Reduce(swap,
465-
target,
466-
notconv * this->nbase_x,
467-
MPI_COMPLEX,
468-
MPI_SUM,
469-
0,
470-
this->diag_comm.comm);
471-
}
472-
else
473-
{
474-
MPI_Reduce(swap,
475-
target,
476-
notconv * this->nbase_x,
477-
MPI_DOUBLE_COMPLEX,
478-
MPI_SUM,
479-
0,
480-
this->diag_comm.comm);
481-
}
482-
483-
syncmem_complex_op()(this->ctx, this->ctx, hcc + nbase * this->nbase_x, target, notconv * this->nbase_x);
484-
syncmem_complex_op()(this->ctx, this->ctx, swap, scc + nbase * this->nbase_x, notconv * this->nbase_x);
485-
486-
if (base_device::get_current_precision(swap) == "single")
487-
{
488-
MPI_Reduce(swap,
489-
target,
490-
notconv * this->nbase_x,
491-
MPI_COMPLEX,
492-
MPI_SUM,
493-
0,
494-
this->diag_comm.comm);
495-
}
496-
else
497-
{
498-
MPI_Reduce(swap,
499-
target,
500-
notconv * this->nbase_x,
501-
MPI_DOUBLE_COMPLEX,
502-
MPI_SUM,
503-
0,
504-
this->diag_comm.comm);
505-
}
506-
}
507-
syncmem_complex_op()(this->ctx, this->ctx, scc + nbase * this->nbase_x, target, notconv * this->nbase_x);
508-
delete[] swap;
509-
delete[] target;
450+
dsp_dav_subspace_reduce(hcc, scc, this->nbase_x, this->notconv, this->diag_comm);
510451
#else
511452
auto* swap = new T[notconv * this->nbase_x];
512453

0 commit comments

Comments
 (0)