Skip to content

Commit ddf990f

Browse files
Feature: Optimized memory management on DSP (#5361)
* Initial commit * Change memory_op construction * I finally find this * Fix template bug * Fix memory header definition * Optimize memory op usage * Update diago_subspace * No change * Fix MPI Error * Make the extra memory usage DSP-hardware-specialized. Add some annotations. * Reorganize dsp codes * Fix bug 1 * Fix bug 2 * Finish transporting codes --------- Co-authored-by: Mohan Chen <[email protected]>
1 parent 535c7f9 commit ddf990f

File tree

8 files changed

+177
-7
lines changed

8 files changed

+177
-7
lines changed

source/module_base/blas_connector.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
9393
}
9494
#ifdef __DSP
9595
else if (device_type == base_device::AbacusDevice_t::DspDevice){
96-
sgemm_mt_(&transb, &transa, &n, &m, &k,
96+
sgemm_mth_(&transb, &transa, &n, &m, &k,
9797
&alpha, b, &ldb, a, &lda,
9898
&beta, c, &ldc, GlobalV::MY_RANK);
9999
}
@@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
111111
}
112112
#ifdef __DSP
113113
else if (device_type == base_device::AbacusDevice_t::DspDevice){
114-
dgemm_mt_(&transb, &transa, &n, &m, &k,
114+
dgemm_mth_(&transb, &transa, &n, &m, &k,
115115
&alpha, b, &ldb, a, &lda,
116116
&beta, c, &ldc, GlobalV::MY_RANK);
117117
}
@@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
129129
}
130130
#ifdef __DSP
131131
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
132-
cgemm_mt_(&transb, &transa, &n, &m, &k,
132+
cgemm_mth_(&transb, &transa, &n, &m, &k,
133133
&alpha, b, &ldb, a, &lda,
134134
&beta, c, &ldc, GlobalV::MY_RANK);
135135
}
@@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
147147
}
148148
#ifdef __DSP
149149
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
150-
zgemm_mt_(&transb, &transa, &n, &m, &k,
150+
zgemm_mth_(&transb, &transa, &n, &m, &k,
151151
&alpha, b, &ldb, a, &lda,
152152
&beta, c, &ldc, GlobalV::MY_RANK);
153153
}

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#define DSP_CONNECTOR_H
33
#ifdef __DSP
44

5+
#include "module_base/module_device/device.h"
6+
#include "module_base/module_device/memory_op.h"
7+
#include "module_hsolver/diag_comm_info.h"
8+
59
// Base dsp functions
610
void dspInitHandle(int id);
711
void dspDestoryHandle(int id);
@@ -62,5 +66,66 @@ void cgemm_mth_(const char *transa, const char *transb,
6266

6367
//#define zgemm_ zgemm_mt
6468

69+
// The next is dsp utils. It may be moved to other files if this file get too huge
70+
71+
template <typename T>
72+
void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm){
73+
74+
using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;
75+
76+
auto* swap = new T[notconv * nbase_x];
77+
auto* target = new T[notconv * nbase_x];
78+
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, hcc + nbase * nbase_x, notconv * nbase_x);
79+
if (base_device::get_current_precision(swap) == "single")
80+
{
81+
MPI_Reduce(swap,
82+
target,
83+
notconv * nbase_x,
84+
MPI_COMPLEX,
85+
MPI_SUM,
86+
0,
87+
diag_comm);
88+
}
89+
else
90+
{
91+
MPI_Reduce(swap,
92+
target,
93+
notconv * nbase_x,
94+
MPI_DOUBLE_COMPLEX,
95+
MPI_SUM,
96+
0,
97+
diag_comm);
98+
}
99+
100+
syncmem_complex_op()(cpu_ctx, cpu_ctx, hcc + nbase * nbase_x, target, notconv * nbase_x);
101+
syncmem_complex_op()(cpu_ctx, cpu_ctx, swap, scc + nbase * nbase_x, notconv * nbase_x);
102+
103+
if (base_device::get_current_precision(swap) == "single")
104+
{
105+
MPI_Reduce(swap,
106+
target,
107+
notconv * nbase_x,
108+
MPI_COMPLEX,
109+
MPI_SUM,
110+
0,
111+
diag_comm);
112+
}
113+
else
114+
{
115+
MPI_Reduce(swap,
116+
target,
117+
notconv * nbase_x,
118+
MPI_DOUBLE_COMPLEX,
119+
MPI_SUM,
120+
0,
121+
diag_comm);
122+
}
123+
124+
syncmem_complex_op()(cpu_ctx, cpu_ctx, scc + nbase * nbase_x, target, notconv * nbase_x);
125+
delete[] swap;
126+
delete[] target;
127+
}
128+
129+
65130
#endif
66131
#endif

source/module_base/module_device/memory_op.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,5 +346,57 @@ template struct delete_memory_op<std::complex<float>, base_device::DEVICE_GPU>;
346346
template struct delete_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
347347
#endif
348348

349+
#ifdef __DSP
350+
351+
template <typename FPTYPE>
352+
struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
353+
{
354+
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE*& arr, const size_t size, const char* record_in)
355+
{
356+
if (arr != nullptr)
357+
{
358+
free_ht(arr);
359+
}
360+
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK);
361+
std::string record_string;
362+
if (record_in != nullptr)
363+
{
364+
record_string = record_in;
365+
}
366+
else
367+
{
368+
record_string = "no_record";
369+
}
370+
371+
if (record_string != "no_record")
372+
{
373+
ModuleBase::Memory::record(record_string, sizeof(FPTYPE) * size);
374+
}
375+
}
376+
};
377+
378+
template <typename FPTYPE>
379+
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
380+
{
381+
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
382+
{
383+
free_ht(arr);
384+
}
385+
};
386+
387+
388+
template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
389+
template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
390+
template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
391+
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
392+
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
393+
394+
template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
395+
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
396+
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
397+
template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
398+
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
399+
#endif
400+
349401
} // namespace memory
350402
} // namespace base_device

source/module_base/module_device/memory_op.h

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,36 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_GPU>
146146
};
147147
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
148148

149+
#ifdef __DSP
150+
151+
template <typename FPTYPE, typename Device>
152+
struct resize_memory_op_mt
153+
{
154+
/// @brief Allocate memory for a given pointer. Note this op will free the pointer first.
155+
///
156+
/// Input Parameters
157+
/// \param dev : the type of computing device
158+
/// \param size : array size
159+
/// \param record_string : label for memory record
160+
///
161+
/// Output Parameters
162+
/// \param arr : allocated array
163+
void operator()(const Device* dev, FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
164+
};
165+
166+
template <typename FPTYPE, typename Device>
167+
struct delete_memory_op_mt
168+
{
169+
/// @brief free memory for multi-device
170+
///
171+
/// Input Parameters
172+
/// \param dev : the type of computing device
173+
/// \param arr : the input array
174+
void operator()(const Device* dev, FPTYPE* arr);
175+
};
176+
177+
#endif // __DSP
178+
149179
} // end of namespace memory
150180
} // end of namespace base_device
151181

@@ -233,5 +263,4 @@ using castmem_z2c_d2h_op = base_device::memory::
233263

234264
static base_device::DEVICE_CPU* cpu_ctx = {};
235265
static base_device::DEVICE_GPU* gpu_ctx = {};
236-
237266
#endif // MODULE_DEVICE_MEMORY_H_

source/module_base/module_device/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace base_device
66

77
struct DEVICE_CPU;
88
struct DEVICE_GPU;
9+
struct DEVICE_DSP;
910

1011
enum AbacusDevice_t
1112
{

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "module_base/timer.h"
77
#include "module_hsolver/kernels/dngvd_op.h"
88
#include "module_hsolver/kernels/math_kernel_op.h"
9+
#include "module_base/kernels/dsp/dsp_connector.h"
910

1011
#include <vector>
1112

@@ -182,7 +183,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
182183
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);
183184

184185
#ifdef __DSP
185-
gemm_op_mt<T, Device>()
186+
gemm_op_mt<T, Device>() // In order to not coding another whole template, using this method to minimize the code change.
186187
#else
187188
gemm_op<T, Device>()
188189
#endif
@@ -444,7 +445,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
444445
#ifdef __MPI
445446
if (this->diag_comm.nproc > 1)
446447
{
448+
#ifdef __DSP
449+
// Only on dsp hardware need an extra space to reduce data
450+
dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm);
451+
#else
447452
auto* swap = new T[notconv * this->nbase_x];
453+
448454
syncmem_complex_op()(this->ctx, this->ctx, swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x);
449455

450456
if (std::is_same<T, double>::value)
@@ -499,6 +505,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
499505
}
500506
}
501507
delete[] swap;
508+
#endif
502509
}
503510
#endif
504511

source/module_hsolver/diago_dav_subspace.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,22 @@ class Diago_DavSubspace
139139

140140
bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf);
141141

142+
#ifdef __DSP
143+
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
144+
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
145+
#else
142146
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
143147
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
148+
#endif
144149
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
145150

151+
#ifdef __DSP
152+
using resmem_real_op = base_device::memory::resize_memory_op_mt<Real, Device>;
153+
using delmem_real_op = base_device::memory::delete_memory_op_mt<Real, Device>;
154+
#else
146155
using resmem_real_op = base_device::memory::resize_memory_op<Real, Device>;
147156
using delmem_real_op = base_device::memory::delete_memory_op<Real, Device>;
157+
#endif
148158
using setmem_real_op = base_device::memory::set_memory_op<Real, Device>;
149159

150160
using resmem_real_h_op = base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>;

source/module_psi/psi.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,16 @@ class Psi
143143

144144
bool allocate_inside = true; ///<whether allocate psi inside Psi class
145145

146-
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
146+
#ifdef __DSP
147+
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
148+
using resize_memory_op = base_device::memory::resize_memory_op_mt<T, Device>;
149+
#else
147150
using delete_memory_op = base_device::memory::delete_memory_op<T, Device>;
148151
using resize_memory_op = base_device::memory::resize_memory_op<T, Device>;
152+
#endif
153+
using set_memory_op = base_device::memory::set_memory_op<T, Device>;
149154
using synchronize_memory_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
155+
150156
};
151157

152158
} // end of namespace psi

0 commit comments

Comments
 (0)