Skip to content

Commit a8d71b2

Browse files
Undo CUDA-aware MPI
1 parent f7059d5 commit a8d71b2

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

source/source_basis/module_pw/pw_basis.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,11 @@ class PW_Basis
248248
ModuleBase::FFT_Bundle fft_bundle;
249249
// The in and out pointers may be equal (in-place transform) or different (out-of-place transform).
250250

251-
template <typename FPTYPE>
252-
void real2recip(const FPTYPE* in,
253-
std::complex<FPTYPE>* out,
254-
const bool add = false,
255-
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
251+
template <typename FPTYPE>
252+
void real2recip(const FPTYPE* in,
253+
std::complex<FPTYPE>* out,
254+
const bool add = false,
255+
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
256256
template <typename FPTYPE>
257257
void real2recip(const std::complex<FPTYPE>* in,
258258
std::complex<FPTYPE>* out,
@@ -269,16 +269,16 @@ class PW_Basis
269269
const bool add = false,
270270
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
271271

272+
template <typename FPTYPE>
273+
void real2recip_gpu(const FPTYPE* in,
274+
std::complex<FPTYPE>* out,
275+
const bool add = false,
276+
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
272277
template <typename FPTYPE>
273-
void real2recip_gpu(const FPTYPE* in,
274-
std::complex<FPTYPE>* out,
275-
const bool add = false,
276-
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
277-
template <typename FPTYPE>
278-
void real2recip_gpu(const std::complex<FPTYPE>* in,
279-
std::complex<FPTYPE>* out,
280-
const bool add = false,
281-
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
278+
void real2recip_gpu(const std::complex<FPTYPE>* in,
279+
std::complex<FPTYPE>* out,
280+
const bool add = false,
281+
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
282282
template <typename FPTYPE>
283283
void recip2real_gpu(const std::complex<FPTYPE>* in,
284284
FPTYPE* out,

source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,26 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
244244

245245
// kpar: to real and bcast between same rank_in_pool
246246
// auto request = MPI_REQUEST_
247+
T* psi_mq_real_cpu = new T[wfcpw->nrxx];
247248
if (iq_pool == GlobalV::MY_POOL)
248249
{
249250
const T* psi_mq = get_pw(m_iband, iq_loc);
250251
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
252+
cudaMemcpy(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx, cudaMemcpyDeviceToHost);
251253
// send
252254
}
253255
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
256+
if (iq_pool == GlobalV::MY_POOL)
257+
{
258+
const T* psi_mq = get_pw(m_iband, iq_loc);
259+
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
260+
// send
261+
}
262+
else
263+
{
264+
cudaMemcpy(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx, cudaMemcpyHostToDevice);
265+
}
266+
delete[] psi_mq_real_cpu;
254267
// std::cout << "psi_mq_real[0]: " << psi_mq_real[0] << std::endl;
255268
// if (GlobalV::RANK_IN_POOL == 1)
256269
// {

0 commit comments

Comments
 (0)