Commit 4f24415

MPI multi-process compatibility

1 parent 38ad956 commit 4f24415
File tree: 10 files changed, +370 −196 lines

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 41 additions & 6 deletions
```diff
@@ -43,6 +43,14 @@ ESolver_KS_LCAO_TDDFT<Device>::ESolver_KS_LCAO_TDDFT()
 {
     classname = "ESolver_KS_LCAO_TDDFT";
     basisname = "LCAO";
+
+    // If the device is GPU, we must enable use_tensor and use_lapack
+    ct::DeviceType ct_device_type = ct::DeviceTypeToEnum<Device>::value;
+    if (ct_device_type == ct::DeviceType::GpuDevice)
+    {
+        use_tensor = true;
+        use_lapack = true;
+    }
 }
 
 template <typename Device>
```
```diff
@@ -217,22 +225,26 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell, const int istep,
 
     if (td_htype == 1)
     {
+        const int len_HS = use_tensor && use_lapack ? nlocal * nlocal : nloc;
+
         if (this->Hk_laststep == nullptr)
         {
             this->Hk_laststep = new std::complex<double>*[kv.get_nks()];
             for (int ik = 0; ik < kv.get_nks(); ++ik)
             {
-                this->Hk_laststep[ik] = new std::complex<double>[nloc];
-                ModuleBase::GlobalFunc::ZEROS(Hk_laststep[ik], nloc);
+                // Allocate memory for Hk_laststep; under (use_tensor && use_lapack) it holds the global matrix
+                this->Hk_laststep[ik] = new std::complex<double>[len_HS];
+                ModuleBase::GlobalFunc::ZEROS(Hk_laststep[ik], len_HS);
             }
         }
         if (this->Sk_laststep == nullptr)
         {
             this->Sk_laststep = new std::complex<double>*[kv.get_nks()];
             for (int ik = 0; ik < kv.get_nks(); ++ik)
             {
-                this->Sk_laststep[ik] = new std::complex<double>[nloc];
-                ModuleBase::GlobalFunc::ZEROS(Sk_laststep[ik], nloc);
+                // Allocate memory for Sk_laststep; under (use_tensor && use_lapack) it holds the global matrix
+                this->Sk_laststep[ik] = new std::complex<double>[len_HS];
+                ModuleBase::GlobalFunc::ZEROS(Sk_laststep[ik], len_HS);
            }
         }
     }
```
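The only change in these two allocation loops is the buffer length: `nloc` is the size of this rank's 2D block-cyclic panel, while the gathered copy must hold the full `nlocal * nlocal` global matrix on every rank. A toy illustration of the size switch (a minimal sketch; the concrete numbers are assumed for the example):

```cpp
#include <complex>
#include <iostream>
#include <vector>

int main()
{
    // Example values only: 8 global orbitals distributed over a BLACS grid
    const int nlocal = 8;  // global basis size
    const int nloc = 16;   // elements in this rank's local block-cyclic panel
    const bool use_tensor = true, use_lapack = true;

    // Same selection as in update_pot(): full global matrix vs. local panel
    const int len_HS = (use_tensor && use_lapack) ? nlocal * nlocal : nloc;

    // std::vector zero-initializes, mirroring ModuleBase::GlobalFunc::ZEROS
    std::vector<std::complex<double>> Hk_laststep(len_HS);
    std::cout << "len_HS = " << Hk_laststep.size() << '\n'; // prints 64
    return 0;
}
```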
```diff
@@ -253,8 +265,31 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell, const int istep,
             this->p_hamilt->updateHk(ik);
             hamilt::MatrixBlock<complex<double>> h_mat, s_mat;
             this->p_hamilt->matrix(h_mat, s_mat);
-            BlasConnector::copy(nloc, h_mat.p, 1, Hk_laststep[ik], 1);
-            BlasConnector::copy(nloc, s_mat.p, 1, Sk_laststep[ik], 1);
+
+            if (use_tensor && use_lapack)
+            {
+                // Gather H and S matrices to the root process
+#ifdef __MPI
+                int myid, num_procs;
+                MPI_Comm_rank(MPI_COMM_WORLD, &myid);
+                MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
+
+                Matrix_g<std::complex<double>> h_mat_g, s_mat_g; // Global matrix structures
+
+                // Collect H matrix
+                gatherMatrix(myid, 0, h_mat, h_mat_g);
+                BlasConnector::copy(nlocal * nlocal, h_mat_g.p.get(), 1, Hk_laststep[ik], 1);
+
+                // Collect S matrix
+                gatherMatrix(myid, 0, s_mat, s_mat_g);
+                BlasConnector::copy(nlocal * nlocal, s_mat_g.p.get(), 1, Sk_laststep[ik], 1);
+#endif
+            }
+            else
+            {
+                BlasConnector::copy(nloc, h_mat.p, 1, Hk_laststep[ik], 1);
+                BlasConnector::copy(nloc, s_mat.p, 1, Sk_laststep[ik], 1);
+            }
         }
     }
 
```
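A point worth keeping in mind when reading this hunk: `Cpxgemr2d` (and hence `gatherMatrix`) is collective over the BLACS context, so every rank must make the call, even though only the root's copy is meaningful afterwards. A minimal sketch of the pattern, assuming the `gatherMatrix` helper declared in `esolver_ks_lcao_tddft.h`:

```cpp
#ifdef __MPI
// Sketch only: gathers the distributed H and S blocks into full
// nlocal x nlocal buffers. Hk_buf/Sk_buf play the role of
// Hk_laststep[ik]/Sk_laststep[ik] above.
void gather_hs_to_root(const hamilt::MatrixBlock<std::complex<double>>& h_mat,
                       const hamilt::MatrixBlock<std::complex<double>>& s_mat,
                       std::complex<double>* Hk_buf,
                       std::complex<double>* Sk_buf,
                       const int nlocal)
{
    int myid;
    MPI_Comm_rank(MPI_COMM_WORLD, &myid);

    ModuleESolver::Matrix_g<std::complex<double>> h_g, s_g;
    ModuleESolver::gatherMatrix(myid, 0, h_mat, h_g); // collective call
    ModuleESolver::gatherMatrix(myid, 0, s_mat, s_g); // collective call

    // The gathered data is only meaningful on rank 0, but gatherMatrix()
    // allocates the buffer on every rank, so these copies are safe
    // (if redundant) elsewhere.
    BlasConnector::copy(nlocal * nlocal, h_g.p.get(), 1, Hk_buf, 1);
    BlasConnector::copy(nlocal * nlocal, s_g.p.get(), 1, Sk_buf, 1);
}
#endif
```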
source/module_esolver/esolver_ks_lcao_tddft.h

Lines changed: 43 additions & 4 deletions
```diff
@@ -2,11 +2,51 @@
 #define ESOLVER_KS_LCAO_TDDFT_H
 #include "esolver_ks.h"
 #include "esolver_ks_lcao.h"
+#include "module_base/scalapack_connector.h" // Cpxgemr2d
 #include "module_hamilt_lcao/hamilt_lcaodft/record_adj.h"
 #include "module_psi/psi.h"
 
 namespace ModuleESolver
 {
+//------------------------ MPI gathering and distributing functions ------------------------//
+// This struct is used for collecting matrices from all processes to the root process
+template <typename T>
+struct Matrix_g
+{
+    std::shared_ptr<T> p;
+    size_t row;
+    size_t col;
+    std::shared_ptr<int> desc;
+};
+
+// Collect matrices from all processes to the root process
+template <typename T>
+void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock<T>& mat_l, Matrix_g<T>& mat_g)
+{
+    const int* desca = mat_l.desc; // Obtain the descriptor of the local matrix
+    int ctxt = desca[1];           // BLACS context
+    int nrows = desca[2];          // Global matrix row number
+    int ncols = desca[3];          // Global matrix column number
+
+    if (myid == root_proc)
+    {
+        mat_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
+    }
+    else
+    {
+        mat_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
+    }
+
+    // Set the descriptor of the global matrix
+    mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
+    mat_g.row = nrows;
+    mat_g.col = ncols;
+
+    // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
+    Cpxgemr2d(nrows, ncols, mat_l.p, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
+}
+//------------------------ MPI gathering and distributing functions ------------------------//
 
 template <typename Device = base_device::DEVICE_CPU>
 class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, double>
```
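`gatherMatrix` reads the BLACS context and the global dimensions out of the standard nine-integer ScaLAPACK array descriptor (`dtype, ctxt, M, N, MB, NB, rsrc, csrc, LLD`), then builds a target descriptor whose block size equals the whole matrix, so `Cpxgemr2d` lands everything in a single block on the process at grid coordinates (0, 0). A hypothetical helper that spells out that target descriptor:

```cpp
#include <array>

// Illustration only: the descriptor gatherMatrix() constructs inline.
// With MB == M and NB == N the matrix is one single block, owned by the
// process at (rsrc, csrc) = (0, 0) -- i.e. the root of the gather.
std::array<int, 9> make_gathered_desc(const int ctxt, const int M, const int N)
{
    return {1,    // dtype: dense matrix
            ctxt, // BLACS context shared with the source descriptor
            M,    // global row count
            N,    // global column count
            M,    // row block size == M    -> one block row
            N,    // column block size == N -> one block column
            0,    // rsrc: process row owning the first (only) block
            0,    // csrc: process column owning the first (only) block
            M};   // LLD: local leading dimension on the owning process
}
```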
```diff
@@ -38,10 +78,9 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, double>
 
     const int td_htype = 1;
 
-    // const bool use_tensor = true;
-    const bool use_tensor = false;
-    const bool use_lapack = true;
-    // const bool use_lapack = false;
+    //! Control heterogeneous computing of the TDDFT solver
+    bool use_tensor = false;
+    bool use_lapack = false;
 
   private:
     void weight_dm_rho();
```

source/module_hamilt_lcao/module_tddft/bandenergy.cpp

Lines changed: 8 additions & 12 deletions
```diff
@@ -290,10 +290,6 @@ void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
                                const ct::Tensor& psi_k,
                                ct::Tensor& ekb)
 {
-    /// ctx is nothing but the devices used in op (Device* ctx = nullptr;),
-    /// it controls the ops to use the corresponding device to calculate results
-    Device* ctx = {};
-    base_device::DEVICE_CPU* cpu_ctx = {};
     // ct_device_type = ct::DeviceType::CpuDevice or ct::DeviceType::GpuDevice
     ct::DeviceType ct_device_type = ct::DeviceTypeToEnum<Device>::value;
     // ct_Device = ct::DEVICE_CPU or ct::DEVICE_GPU
@@ -302,12 +298,12 @@ void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
     // Create Tensor objects for temporary data
     ct::Tensor tmp1(ct::DataType::DT_COMPLEX_DOUBLE,
                     ct_device_type,
-                    ct::TensorShape({pv->nloc_wfc})); // tmp1 shape: nlocal * nband
+                    ct::TensorShape({nlocal * nband})); // tmp1 shape: nlocal * nband
     tmp1.zero();
 
     ct::Tensor Eij(ct::DataType::DT_COMPLEX_DOUBLE,
                    ct_device_type,
-                   ct::TensorShape({pv->nloc})); // Eij shape: nlocal * nlocal
+                   ct::TensorShape({nlocal * nlocal})); // Eij shape: nlocal * nlocal
     // Why not use nband * nband ?????
     Eij.zero();
 
@@ -346,17 +342,19 @@ void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
 
     if (PARAM.inp.td_print_eij >= 0.0)
     {
+        ct::Tensor Eij_cpu = Eij.to_device<ct::DEVICE_CPU>();
+
         GlobalV::ofs_running
             << "------------------------------------------------------------------------------------------------"
             << std::endl;
         GlobalV::ofs_running << " Eij:" << std::endl;
-        for (int i = 0; i < pv->nrow_bands; i++)
+        for (int i = 0; i < nband; i++)
         {
-            for (int j = 0; j < pv->ncol_bands; j++)
+            for (int j = 0; j < nband; j++)
             {
                 double aa = 0.0, bb = 0.0;
-                aa = Eij.data<std::complex<double>>()[i * pv->ncol + j].real();
-                bb = Eij.data<std::complex<double>>()[i * pv->ncol + j].imag();
+                aa = Eij_cpu.data<std::complex<double>>()[i * nlocal + j].real();
+                bb = Eij_cpu.data<std::complex<double>>()[i * nlocal + j].imag();
                 if (std::abs(aa) < PARAM.inp.td_print_eij)
                 {
                     aa = 0.0;
@@ -384,8 +382,6 @@ void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
     for (int i = 0; i < nband; ++i)
     {
         base_device::memory::synchronize_memory_op<double, Device, Device>()(
-            ctx,
-            ctx,
             ekb.data<double>() + i,
             reinterpret_cast<const double*>(Eij.data<std::complex<double>>() + i * nlocal + i),
             1);
```
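The final loop copies one `double` per band, starting at the diagonal element `Eij(i, i)` reinterpreted as a pair of doubles, which picks out exactly its real part: the band energy. A plain-CPU sketch of the same extraction (the row-major `nlocal * nlocal` layout matches the gathered `Eij` above):

```cpp
#include <complex>
#include <vector>

// Sketch: band i's energy is Re(Eij(i, i)). std::complex<double> is
// layout-compatible with double[2] (real part first), which is what the
// reinterpret_cast in the diff relies on when copying a single double.
std::vector<double> diagonal_real_parts(const std::vector<std::complex<double>>& Eij,
                                        const int nlocal,
                                        const int nband)
{
    std::vector<double> ekb(nband);
    for (int i = 0; i < nband; ++i)
    {
        ekb[i] = Eij[static_cast<size_t>(i) * nlocal + i].real();
    }
    return ekb;
}
```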

source/module_hamilt_lcao/module_tddft/evolve_elec.cpp

Lines changed: 81 additions & 57 deletions
```diff
@@ -15,10 +15,6 @@ Evolve_elec<Device>::Evolve_elec(){};
 template <typename Device>
 Evolve_elec<Device>::~Evolve_elec(){};
 
-template <typename Device>
-Device* Evolve_elec<Device>::ctx = {};
-template <typename Device>
-base_device::DEVICE_CPU* Evolve_elec<Device>::cpu_ctx = {};
 template <typename Device>
 ct::DeviceType Evolve_elec<Device>::ct_device_type = ct::DeviceTypeToEnum<Device>::value;
 
```
```diff
@@ -89,53 +85,69 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
     }
     else
     {
-        // std::cout << "nband = " << nband << std::endl;
-        // std::cout << "psi->get_nbands() = " << psi->get_nbands() << std::endl;
-        // std::cout << "nlocal = " << nlocal << std::endl;
-        // std::cout << "psi->get_nbasis() = " << psi->get_nbasis() << std::endl;
-        // std::cout << "para_orb.nloc = " << para_orb.nloc << std::endl;
-        // std::cout << "para_orb.nrow = " << para_orb.nrow << std::endl;
-        // std::cout << "para_orb.ncol = " << para_orb.ncol << std::endl;
-        // std::cout << "ekb.nr = " << ekb.nr << std::endl;
-        // std::cout << "ekb.nc = " << ekb.nc << std::endl;
+        const int len_psi_k_1 = use_lapack ? nband : psi->get_nbands();
+        const int len_psi_k_2 = use_lapack ? nlocal : psi->get_nbasis();
+        const int len_HS_laststep = use_lapack ? nlocal * nlocal : para_orb.nloc;
 
         // Create Tensor for psi_k, psi_k_laststep, H_laststep, S_laststep, ekb
         ct::Tensor psi_k_tensor(ct::DataType::DT_COMPLEX_DOUBLE,
                                 ct_device_type,
-                                ct::TensorShape({psi->get_nbands(), psi->get_nbasis()}));
+                                ct::TensorShape({len_psi_k_1, len_psi_k_2}));
         ct::Tensor psi_k_laststep_tensor(ct::DataType::DT_COMPLEX_DOUBLE,
                                          ct_device_type,
-                                         ct::TensorShape({psi->get_nbands(), psi->get_nbasis()}));
+                                         ct::TensorShape({len_psi_k_1, len_psi_k_2}));
         ct::Tensor H_laststep_tensor(ct::DataType::DT_COMPLEX_DOUBLE,
                                      ct_device_type,
-                                     ct::TensorShape({para_orb.nloc}));
+                                     ct::TensorShape({len_HS_laststep}));
         ct::Tensor S_laststep_tensor(ct::DataType::DT_COMPLEX_DOUBLE,
                                      ct_device_type,
-                                     ct::TensorShape({para_orb.nloc}));
+                                     ct::TensorShape({len_HS_laststep}));
         ct::Tensor ekb_tensor(ct::DataType::DT_DOUBLE, ct_device_type, ct::TensorShape({nband}));
 
-        // Syncronize data from CPU to Device
-        syncmem_complex_h2d_op()(ctx,
-                                 cpu_ctx,
-                                 psi_k_tensor.data<std::complex<double>>(),
-                                 psi[0].get_pointer(),
-                                 psi->get_nbands() * psi->get_nbasis());
-        syncmem_complex_h2d_op()(ctx,
-                                 cpu_ctx,
-                                 psi_k_laststep_tensor.data<std::complex<double>>(),
-                                 psi_laststep[0].get_pointer(),
-                                 psi->get_nbands() * psi->get_nbasis());
-        syncmem_complex_h2d_op()(ctx,
-                                 cpu_ctx,
-                                 H_laststep_tensor.data<std::complex<double>>(),
+        // Global psi
+        ModuleESolver::Matrix_g<std::complex<double>> psi_g;
+        ModuleESolver::Matrix_g<std::complex<double>> psi_laststep_g;
+
+        if (use_lapack)
+        {
+            // Need to gather the psi to the root process on CPU
+            // H_laststep and S_laststep are already gathered in esolver_ks_lcao_tddft.cpp
+#ifdef __MPI
+            // Access the rank of the calling process in the communicator
+            int myid, root_proc = 0;
+            MPI_Comm_rank(MPI_COMM_WORLD, &myid);
+
+            // Gather psi to the root process
+            gatherPsi(myid, root_proc, psi[0].get_pointer(), para_orb, psi_g);
+            gatherPsi(myid, root_proc, psi_laststep[0].get_pointer(), para_orb, psi_laststep_g);
+
+            // Synchronize data from CPU to Device
+            syncmem_complex_h2d_op()(psi_k_tensor.data<std::complex<double>>(),
+                                     psi_g.p.get(),
+                                     len_psi_k_1 * len_psi_k_2);
+            syncmem_complex_h2d_op()(psi_k_laststep_tensor.data<std::complex<double>>(),
+                                     psi_laststep_g.p.get(),
+                                     len_psi_k_1 * len_psi_k_2);
+#endif
+        }
+        else
+        {
+            // Synchronize data from CPU to Device
+            syncmem_complex_h2d_op()(psi_k_tensor.data<std::complex<double>>(),
+                                     psi[0].get_pointer(),
+                                     len_psi_k_1 * len_psi_k_2);
+            syncmem_complex_h2d_op()(psi_k_laststep_tensor.data<std::complex<double>>(),
+                                     psi_laststep[0].get_pointer(),
+                                     len_psi_k_1 * len_psi_k_2);
+        }
+
+        syncmem_complex_h2d_op()(H_laststep_tensor.data<std::complex<double>>(),
                                  Hk_laststep[ik],
-                                 para_orb.nloc);
-        syncmem_complex_h2d_op()(ctx,
-                                 cpu_ctx,
-                                 S_laststep_tensor.data<std::complex<double>>(),
+                                 len_HS_laststep);
+        syncmem_complex_h2d_op()(S_laststep_tensor.data<std::complex<double>>(),
                                  Sk_laststep[ik],
-                                 para_orb.nloc);
-        syncmem_double_h2d_op()(ctx, cpu_ctx, ekb_tensor.data<double>(), &(ekb(ik, 0)), nband);
+                                 len_HS_laststep);
+        syncmem_double_h2d_op()(ekb_tensor.data<double>(), &(ekb(ik, 0)), nband);
 
         evolve_psi_tensor<Device>(nband,
                                   nlocal,
```
```diff
@@ -151,28 +163,40 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
                                   print_matrix,
                                   use_lapack);
 
-        // Syncronize data from Device to CPU
-        syncmem_complex_d2h_op()(cpu_ctx,
-                                 ctx,
-                                 psi[0].get_pointer(),
-                                 psi_k_tensor.data<std::complex<double>>(),
-                                 psi->get_nbands() * psi->get_nbasis());
-        syncmem_complex_d2h_op()(cpu_ctx,
-                                 ctx,
-                                 psi_laststep[0].get_pointer(),
-                                 psi_k_laststep_tensor.data<std::complex<double>>(),
-                                 psi->get_nbands() * psi->get_nbasis());
-        syncmem_complex_d2h_op()(cpu_ctx,
-                                 ctx,
-                                 Hk_laststep[ik],
+        // Need to distribute global psi back to all processes
+        if (use_lapack)
+        {
+#ifdef __MPI
+            // Synchronize data from Device to CPU
+            syncmem_complex_d2h_op()(psi_g.p.get(),
+                                     psi_k_tensor.data<std::complex<double>>(),
+                                     len_psi_k_1 * len_psi_k_2);
+            syncmem_complex_d2h_op()(psi_laststep_g.p.get(),
+                                     psi_k_laststep_tensor.data<std::complex<double>>(),
+                                     len_psi_k_1 * len_psi_k_2);
+
+            // Distribute psi to all processes
+            distributePsi(para_orb, psi[0].get_pointer(), psi_g);
+            distributePsi(para_orb, psi_laststep[0].get_pointer(), psi_laststep_g);
+#endif
+        }
+        else
+        {
+            // Synchronize data from Device to CPU
+            syncmem_complex_d2h_op()(psi[0].get_pointer(),
+                                     psi_k_tensor.data<std::complex<double>>(),
+                                     len_psi_k_1 * len_psi_k_2);
+            syncmem_complex_d2h_op()(psi_laststep[0].get_pointer(),
+                                     psi_k_laststep_tensor.data<std::complex<double>>(),
+                                     len_psi_k_1 * len_psi_k_2);
+        }
+        syncmem_complex_d2h_op()(Hk_laststep[ik],
                                  H_laststep_tensor.data<std::complex<double>>(),
-                                 para_orb.nloc);
-        syncmem_complex_d2h_op()(cpu_ctx,
-                                 ctx,
-                                 Sk_laststep[ik],
+                                 len_HS_laststep);
+        syncmem_complex_d2h_op()(Sk_laststep[ik],
                                  S_laststep_tensor.data<std::complex<double>>(),
-                                 para_orb.nloc);
-        syncmem_double_d2h_op()(cpu_ctx, ctx, &(ekb(ik, 0)), ekb_tensor.data<double>(), nband);
+                                 len_HS_laststep);
+        syncmem_double_d2h_op()(&(ekb(ik, 0)), ekb_tensor.data<double>(), nband);
 
         // std::cout << "Print ekb tensor: " << std::endl;
         // ekb.print(std::cout);
```
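`gatherPsi` and `distributePsi` are not shown in this hunk (they live in another file of this commit); presumably `distributePsi` is the inverse of `gatherMatrix`, i.e. a `Cpxgemr2d` redistribution from the one-block global descriptor back to the block-cyclic layout in `para_orb`. A hypothetical sketch under that assumption (`desc_wfc` as the wavefunction descriptor is also an assumption):

```cpp
#ifdef __MPI
// Sketch only -- not the committed implementation. Scatters the gathered
// global psi back into each rank's block-cyclic panel by swapping the
// source/target descriptors of the gather.
template <typename T>
void distributePsi_sketch(const Parallel_Orbitals& para_orb,
                          T* psi_local, // this rank's panel, written in place
                          const ModuleESolver::Matrix_g<T>& psi_g)
{
    const int* desc_local = para_orb.desc_wfc; // assumed descriptor field
    const int ctxt = desc_local[1];            // shared BLACS context

    Cpxgemr2d(psi_g.row, psi_g.col,
              psi_g.p.get(), 1, 1, psi_g.desc.get(),         // source: global copy
              psi_local, 1, 1, const_cast<int*>(desc_local), // target: 2D panel
              ctxt);
}
#endif
```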
