Skip to content

Commit ede3f95

Browse files
committed
fix bug in GPU-BPCG
1 parent 6970f4c commit ede3f95

File tree

9 files changed

+154
-70
lines changed

9 files changed

+154
-70
lines changed

source/module_base/para_gemm.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,17 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
153153
const Device* ctx = {};
154154

155155
std::vector<T> B_tmp(max_colA * LDA);
156+
std::vector<T> isend_tmp;
157+
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
158+
{
159+
isend_tmp.resize(max_colA * LDA);
160+
}
156161
for (int ip = 0; ip < col_nproc; ip++)
157162
{
158163
if (col_rank != ip)
159164
{
160165
int size = ncolA * LDA;
161-
Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], B_tmp.data());
166+
Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data());
162167
}
163168
}
164169

source/module_base/parallel_device.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ void isend_data(const std::complex<float>* buf, int count, int dest, int tag, MP
1818
{
1919
MPI_Isend(buf, count, MPI_COMPLEX, dest, tag, comm, request);
2020
}
21+
void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm)
22+
{
23+
MPI_Send(buf, count, MPI_DOUBLE, dest, tag, comm);
24+
}
25+
void send_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm)
26+
{
27+
MPI_Send(buf, count, MPI_DOUBLE_COMPLEX, dest, tag, comm);
28+
}
29+
void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm)
30+
{
31+
MPI_Send(buf, count, MPI_FLOAT, dest, tag, comm);
32+
}
33+
void send_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm)
34+
{
35+
MPI_Send(buf, count, MPI_COMPLEX, dest, tag, comm);
36+
}
2137
void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status)
2238
{
2339
MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, status);

source/module_base/parallel_device.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm,
1111
void isend_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
1212
void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
1313
void isend_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
14+
void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm);
15+
void send_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm);
16+
void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm);
17+
void send_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm);
1418
void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
1519
void recv_data(std::complex<double>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
1620
void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
@@ -39,15 +43,31 @@ struct object_cpu_point
3943
};
4044

4145
/**
42-
* @brief isend data in Device
46+
* @brief send data in Device
4347
*
4448
*/
4549
template <typename T, typename Device>
46-
void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* tmp_space = nullptr)
50+
void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr)
4751
{
4852
object_cpu_point<T,Device> o;
4953
T* object_cpu = o.get(object, count, tmp_space);
5054
o.sync_d2h(object_cpu, object, count);
55+
send_data(object_cpu, count, dest, tag, comm);
56+
o.del(object_cpu);
57+
return;
58+
}
59+
60+
/**
61+
* @brief isend data in Device
62+
* @note before the date in send_space is recieved, it should not be modified
63+
*
64+
*/
65+
template <typename T, typename Device>
66+
void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space)
67+
{
68+
object_cpu_point<T,Device> o;
69+
T* object_cpu = o.get(object, count, send_space);
70+
o.sync_d2h(object_cpu, object, count);
5171
isend_data(object_cpu, count, dest, tag, comm, request);
5272
o.del(object_cpu);
5373
return;

source/module_hsolver/para_linear_transform.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
6060
{
6161
std::vector<MPI_Request> requests(nproc_col);
6262
std::vector<T> A_tmp(max_colA * LDA);
63+
std::vector<T> isend_tmp;
6364
T* A_tmp_device = A_tmp.data();
6465
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
6566
{
6667
A_tmp_device = nullptr;
68+
isend_tmp.resize(max_colA * LDA);
6769
resmem_dev_op()(A_tmp_device, max_colA * LDA);
6870
}
6971
T* B_tmp = nullptr;
@@ -80,7 +82,7 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
8082
if (rank_col != ip)
8183
{
8284
int size = LDA * ncolA;
83-
Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], A_tmp.data());
85+
Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], isend_tmp.data());
8486
}
8587
}
8688

source/module_io/read_input_item_system.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ void ReadInput::item_system()
257257
"will be distributed among each group";
258258
read_sync_int(input.bndpar);
259259
item.reset_value = [](const Input_Item& item, Parameter& para) {
260-
if (para.input.esolver_type != "sdft")
260+
if (para.input.esolver_type != "sdft" && para.input.ks_solver != "bpcg")
261261
{
262262
para.input.bndpar = 1;
263263
}

source/module_psi/psi_init.cpp

Lines changed: 95 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "module_base/macros.h"
44
#include "module_base/memory.h"
5+
#include "module_base/parallel_device.h"
56
#include "module_base/timer.h"
67
#include "module_base/tool_quit.h"
78
#include "module_hsolver/diago_iter_assist.h"
@@ -40,8 +41,8 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
4041
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
4142
ModuleBase::timer::tick("PSIInit", "prepare_init");
4243
this->psi_initer.reset();
43-
if (this->init_wfc == "random" || (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1))
44-
{ //temporary solution for band parallel bpcg
44+
if (this->init_wfc == "random")
45+
{ // temporary solution for band parallel bpcg
4546
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
4647
}
4748
else if (this->init_wfc == "file")
@@ -97,30 +98,34 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9798
ModuleBase::timer::tick("PSIInit", "initialize_psi");
9899

99100
const int nbands_start = this->psi_initer->nbands_start();
100-
const int nbands = psi->get_nbands();
101+
const int nbands_l = psi->get_nbands();
101102
const int nbasis = psi->get_nbasis();
102-
const bool not_equal = (nbands_start != nbands);
103+
const bool not_equal = (nbands_start != nbands_l);
103104

104105
Psi<T>* psi_cpu = reinterpret_cast<psi::Psi<T>*>(psi);
105106
Psi<T, Device>* psi_device = kspw_psi;
106107

107-
if (not_equal)
108+
bool fill = PARAM.inp.ks_solver != "bpcg" || GlobalV::MY_BNDGROUP == 0;
109+
if (fill)
108110
{
109-
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nbasis, true);
110-
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
111-
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
112-
}
113-
else if (PARAM.inp.precision == "single")
114-
{
115-
if (PARAM.inp.device == "cpu")
111+
if (not_equal)
116112
{
117-
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
118-
psi_device = kspw_psi;
113+
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nbasis, true);
114+
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
115+
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
119116
}
120-
else
117+
else if (PARAM.inp.precision == "single")
121118
{
122-
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nbasis, true);
123-
psi_device = kspw_psi;
119+
if (PARAM.inp.device == "cpu")
120+
{
121+
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
122+
psi_device = kspw_psi;
123+
}
124+
else
125+
{
126+
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nbasis, true);
127+
psi_device = kspw_psi;
128+
}
124129
}
125130
}
126131

@@ -134,58 +139,90 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
134139

135140
//! Update Hamiltonian from other kpoint to the given one
136141
p_hamilt->updateHk(ik);
137-
138-
//! initialize psi_cpu
139-
this->psi_initer->init_psig(psi_cpu->get_pointer(), ik);
140-
if (psi_device->get_pointer() != psi_cpu->get_pointer())
142+
if (fill)
141143
{
142-
syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
143-
}
144-
145-
std::vector<typename GetTypeReal<T>::type> etatom(nbands_start, 0.0);
144+
//! initialize psi_cpu
145+
this->psi_initer->init_psig(psi_cpu->get_pointer(), ik);
146+
if (psi_device->get_pointer() != psi_cpu->get_pointer())
147+
{
148+
syncmem_h2d_op()(psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
149+
}
146150

147-
if (this->ks_solver == "cg")
148-
{
149-
if (not_equal)
151+
if (this->ks_solver == "cg")
150152
{
151-
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
152-
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
153-
psi_device->get_pointer(),
154-
nbands_start,
155-
nbasis,
156-
*(kspw_psi),
157-
etatom.data());
153+
std::vector<typename GetTypeReal<T>::type> etatom(nbands_start, 0.0);
154+
if (not_equal)
155+
{
156+
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be
157+
// different
158+
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
159+
psi_device->get_pointer(),
160+
nbands_start,
161+
nbasis,
162+
*(kspw_psi),
163+
etatom.data());
164+
}
165+
else
166+
{
167+
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
168+
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
169+
*psi_device,
170+
*kspw_psi,
171+
etatom.data(),
172+
nbands_start);
173+
}
158174
}
159-
else
175+
else // dav, bpcg
160176
{
161-
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
162-
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
163-
*psi_device,
164-
*kspw_psi,
165-
etatom.data(),
166-
nbands_start);
177+
if (psi_device->get_pointer() != kspw_psi->get_pointer())
178+
{
179+
syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands_l * nbasis);
180+
}
167181
}
168182
}
169-
else // dav, bpcg
183+
#ifdef __MPI
184+
if (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1)
170185
{
171-
if (psi_device->get_pointer() != kspw_psi->get_pointer())
186+
std::vector<int> sendcounts(PARAM.inp.bndpar);
187+
std::vector<int> displs(PARAM.inp.bndpar);
188+
MPI_Allgather(&nbands_l, 1, MPI_INT, sendcounts.data(), 1, MPI_INT, BP_WORLD);
189+
displs[0] = 0;
190+
sendcounts[0] *= nbasis;
191+
for (int i = 1; i < PARAM.inp.bndpar; i++)
192+
{
193+
sendcounts[i] *= nbasis;
194+
displs[i] = displs[i - 1] + sendcounts[i - 1];
195+
}
196+
if (GlobalV::MY_BNDGROUP == 0)
172197
{
173-
syncmem_complex_op()(kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis);
198+
for (int ip = 1; ip < PARAM.inp.bndpar; ++ip)
199+
{
200+
Parallel_Common::send_data(psi_cpu->get_pointer() + displs[ip], sendcounts[ip], ip, 0, BP_WORLD);
201+
}
174202
}
203+
else
204+
{
205+
MPI_Status status;
206+
Parallel_Common::recv_dev<T, Device>(kspw_psi->get_pointer(), nbands_l * nbasis, 0, 0, BP_WORLD, &status);
207+
}
175208
}
209+
#endif
176210
} // end k-point loop
177211

178-
if (not_equal)
212+
if (fill)
179213
{
180-
delete psi_cpu;
181-
if(PARAM.inp.device == "gpu")
214+
if (not_equal)
182215
{
183-
delete psi_device;
216+
delete psi_cpu;
217+
if (PARAM.inp.device == "gpu")
218+
{
219+
delete psi_device;
220+
}
221+
}
222+
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
223+
{
224+
delete psi_cpu;
184225
}
185-
}
186-
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
187-
{
188-
delete psi_cpu;
189226
}
190227

191228
ModuleBase::timer::tick("PSIInit", "initialize_psi");
@@ -203,7 +240,11 @@ void PSIInit<T, Device>::initialize_lcao_in_pw(Psi<T>* psi_local, std::ofstream&
203240
}
204241
}
205242

206-
void allocate_psi(Psi<std::complex<double>>*& psi, const int& nks, const std::vector<int>& ngk, const int& nbands, const int& npwx)
243+
void allocate_psi(Psi<std::complex<double>>*& psi,
244+
const int& nks,
245+
const std::vector<int>& ngk,
246+
const int& nbands,
247+
const int& npwx)
207248
{
208249
assert(npwx > 0);
209250
assert(nks > 0);
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
etotref -4869.7470520365577613
2-
etotperatomref -2434.8735260183
3-
totalforceref 5.202524
4-
totalstressref 37241.827525
1+
etotref -4869.7470518349809936
2+
etotperatomref -2434.8735259175
3+
totalforceref 5.207670
4+
totalstressref 37241.465646
55
pointgroupref C_1
66
spacegroupref C_1
77
nksibzref 8
8-
totaltimeref 4.25
8+
totaltimeref 10.28

tests/integrate/187_PW_SDFT_MALL_BPCG_GPU/INPUT

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pseudo_dir ../../PP_ORB
1313
kpar 1
1414
bndpar 2
1515

16-
nbands 7
16+
nbands 11
1717
nbands_sto all
1818

1919
nche_sto 120
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
etotref -96.93611634
2-
etotperatomref -48.46805817
3-
totalforceref 248.97836000
4-
totalstressref 230453.08774800
5-
totaltimeref 6.37
1+
etotref -96.9361190965003630
2+
etotperatomref -48.4680595483
3+
totalforceref 248.979444
4+
totalstressref 230453.604050
5+
totaltimeref 6.44

0 commit comments

Comments
 (0)