Skip to content

Commit 86b09ea

Browse files
denghuilu authored and dyzheng committed
GPU: Add multi device support for HPsi(veff_pw) (#1456)
* add multi device support for hpsi(veff_pw)
* add UTs
* fix compilation errors with cuda environment
* remove cuda flags
* fix CI error
* fix Intel compilation error
1 parent d483952 commit 86b09ea

30 files changed

+997
-121
lines changed

source/module_base/test/memory_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
class MemoryTest : public testing::Test
2424
{
2525
protected:
26-
// definition according to ../memory.cpp
26+
// definition according to ../memory_psi.cpp
2727
double factor = 1.0 / 1024.0 / 1024.0; // MB
2828
double complex_matrix_mem = 2*sizeof(double) * factor; // byte to MB
2929
double double_mem = sizeof(double) * factor;

source/module_elecstate/test/CMakeLists.txt

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ if(ENABLE_LCAO)
1313
../../src_pw/structure_factor.cpp ../../src_pw/pw_complement.cpp
1414
../../src_pw/klist.cpp ../../src_parallel/parallel_kpoints.cpp ../../src_pw/occupy.cpp
1515
)
16+
if(USE_CUDA)
17+
target_link_libraries(EState_updaterhok_pw cufft)
18+
endif()
1619

1720
install(DIRECTORY support DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
1821

@@ -35,6 +38,9 @@ if(ENABLE_LCAO)
3538
../../src_pdiag/pdiag_common.cpp
3639
../../src_io/output.cpp ../../src_pw/soc.cpp ../../src_io/read_rho.cpp
3740
)
41+
if(USE_CUDA)
42+
target_link_libraries(EState_psiToRho_lcao cufft)
43+
endif()
3844
target_compile_definitions(EState_psiToRho_lcao PRIVATE __MPI)
3945
install(FILES elecstate_lcao_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
4046

@@ -83,9 +89,7 @@ add_library(
8389
../../module_base/ylm.cpp
8490
)
8591

86-
add_library(
87-
planewave_serial
88-
OBJECT
92+
list(APPEND planewave_serial_srcs
8993
../../module_pw/fft.cpp
9094
../../module_pw/pw_basis.cpp
9195
../../module_pw/pw_basis_k.cpp
@@ -96,6 +100,17 @@ add_library(
96100
../../module_pw/pw_init.cpp
97101
../../module_pw/pw_transform.cpp
98102
../../module_pw/pw_transform_k.cpp
103+
../../module_pw/src/pw_multi_device.cpp
104+
)
105+
106+
if (USE_CUDA)
107+
list(APPEND planewave_serial_srcs ../../module_pw/src/cuda/pw_multi_device.cu)
108+
endif()
109+
110+
add_library(
111+
planewave_serial
112+
OBJECT
113+
${planewave_serial_srcs}
99114
)
100115

101116
if(ENABLE_COVERAGE)

source/module_hamilt/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ list(APPEND objects
66
hamilt_pw.cpp
77
src/ekinetic.cpp
88
src/nonlocal.cpp
9+
src/veff.cpp
910
ks_pw/ekinetic_pw.cpp
1011
ks_pw/veff_pw.cpp
1112
ks_pw/nonlocal_pw.cpp
@@ -30,9 +31,9 @@ if(ENABLE_LCAO)
3031
endif()
3132

3233
if (USE_CUDA)
33-
list(APPEND objects src/cuda/ekinetic.cu src/cuda/nonlocal.cu)
34+
list(APPEND objects src/cuda/ekinetic.cu src/cuda/nonlocal.cu src/cuda/veff.cu)
3435
elseif(USE_ROCM)
35-
list(APPEND objects src/rocm/ekinetic.cu src/cuda/nonlocal.cu)
36+
# ROCm branch: take every kernel from src/rocm (the CUDA nonlocal.cu belongs to the USE_CUDA branch above).
list(APPEND objects src/rocm/ekinetic.cu src/rocm/nonlocal.cu src/rocm/veff.cu)
3637
endif()
3738

3839
add_library(
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef MODULE_HAMILT_VEFF_H
2+
#define MODULE_HAMILT_VEFF_H
3+
4+
#include "module_psi/psi.h"
5+
#include <complex>
6+
7+
namespace hamilt {
8+
/// Functor interface used by the plane-wave Veff operator; it is invoked as
/// veff_pw_op<FPTYPE, Device>()(dev, size, out, in).
/// Only the call operator is declared here; the generic definition is not part of this header.
/// @param dev   pointer to the device tag type selecting the implementation
/// @param size  number of elements to process
/// @param out   complex array that receives the result
/// @param in    real-valued input array
/// NOTE(review): the GPU kernel computes out[i] *= in[i] for i in [0, size);
/// the generic definition presumably does the same -- confirm.
template <typename FPTYPE, typename Device>
9+
struct veff_pw_op {
10+
void operator() (
11+
const Device* dev,
12+
const int& size,
13+
std::complex<FPTYPE>* out,
14+
const FPTYPE* in);
15+
};
16+
17+
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
18+
// Partially specialize functor for psi::GpuDevice.
19+
/// Declaration of the psi::DEVICE_GPU specialization of veff_pw_op; it is only
/// visible when one of the CUDA/ROCm macros in the enclosing #if is set.
/// The parameter list mirrors the primary template, with Device fixed to psi::DEVICE_GPU.
/// The call operator is defined in the backend-specific .cu source, not in this header.
template <typename FPTYPE>
20+
struct veff_pw_op<FPTYPE, psi::DEVICE_GPU> {
21+
void operator() (
22+
const psi::DEVICE_GPU* dev,
23+
const int& size,
24+
std::complex<FPTYPE>* out,
25+
const FPTYPE* in);
26+
};
27+
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
28+
} // namespace hamilt
29+
#endif //MODULE_HAMILT_VEFF_H

source/module_hamilt/ks_pw/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ list(APPEND operator_ks_pw_srcs
88
../operator.cpp
99
../src/ekinetic.cpp
1010
../src/nonlocal.cpp
11+
../src/veff.cpp
1112
)
1213

1314
if (USE_CUDA)
14-
list(APPEND operator_ks_pw_srcs ../src/cuda/ekinetic.cu ../src/cuda/nonlocal.cu)
15+
list(APPEND operator_ks_pw_srcs ../src/cuda/ekinetic.cu ../src/cuda/nonlocal.cu ../src/cuda/veff.cu)
1516
elseif(USE_ROCM)
16-
list(APPEND operator_ks_pw_srcs ../src/rocm/ekinetic.cu ../src/rocm/nonlocal.cu)
17+
list(APPEND operator_ks_pw_srcs ../src/rocm/ekinetic.cu ../src/rocm/nonlocal.cu ../src/rocm/veff.cu)
1718
endif()
1819

1920
add_library(

source/module_hamilt/ks_pw/nonlocal_pw.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_base/timer.h"
55
#include "src_parallel/parallel_reduce.h"
66
#include "module_base/tool_quit.h"
7+
#include "module_psi/include/device.h"
78

89
using hamilt::Nonlocal;
910
using hamilt::OperatorPW;
@@ -19,6 +20,9 @@ Nonlocal<OperatorPW<FPTYPE, Device>>::Nonlocal(
1920
this->isk = isk_in;
2021
this->ppcell = ppcell_in;
2122
this->ucell = ucell_in;
23+
this->deeq = psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice ?
24+
this->ppcell->d_deeq : // for GpuDevice
25+
this->ppcell->deeq.ptr; // for CpuDevice
2226
if( this->isk == nullptr || this->ppcell == nullptr || this->ucell == nullptr)
2327
{
2428
ModuleBase::WARNING_QUIT("NonlocalPW", "Constuctor of Operator::NonlocalPW is failed, please check your code!");
@@ -78,7 +82,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
7882
this->ucell->atoms[it].na, m, nproj, // four loop size
7983
sum, iat, current_spin, nkb, // additional index params
8084
this->ppcell->deeq.getBound2(), this->ppcell->deeq.getBound3(), this->ppcell->deeq.getBound4(), // realArray operator()
81-
this->ppcell->deeq.ptr, // array of data
85+
this->deeq, // array of data
8286
this->ps, this->becp); // array of data
8387
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
8488
// for (int ia = 0; ia < this->ucell->atoms[it].na; ia++)
@@ -104,6 +108,9 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
104108
}
105109
else
106110
{
111+
#if defined(__CUDA) || defined(__ROCM)
112+
ModuleBase::WARNING_QUIT("NonlocalPW", " gpu implementation of this->npol != 1 is not supported currently !!! ");
113+
#endif
107114
for (int it = 0; it < this->ucell->ntype; it++)
108115
{
109116
int psind = 0;

source/module_hamilt/ks_pw/nonlocal_pw.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class Nonlocal<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
5656
mutable std::complex<FPTYPE>* becp = nullptr;
5757
mutable std::complex<FPTYPE> *ps = nullptr;
5858
Device* ctx = {};
59+
FPTYPE * deeq = nullptr;
5960
// using nonlocal_op = nonlocal_pw_op<FPTYPE, Device>;
6061
using gemv_op = hsolver::gemv_op<FPTYPE, Device>;
6162
using gemm_op = hsolver::gemm_op<FPTYPE, Device>;

source/module_hamilt/ks_pw/veff_pw.cpp

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "veff_pw.h"
22

33
#include "module_base/timer.h"
4-
#include "src_pw/global.h"
54
#include "module_base/tool_quit.h"
65

76
using hamilt::Veff;
@@ -15,14 +14,29 @@ Veff<OperatorPW<FPTYPE, Device>>::Veff(
1514
{
1615
this->cal_type = pw_veff;
1716
this->isk = isk_in;
18-
this->veff = veff_in;
17+
// this->veff = veff_in;
18+
// TODO: add an GPU veff array
19+
this->veff = veff_in[0].c;
20+
this->veff_col = veff_in[0].nc;
1921
this->wfcpw = wfcpw_in;
20-
if( this->isk == nullptr || this->veff == nullptr || this->wfcpw == nullptr)
21-
{
22+
resize_memory_op()(this->ctx, this->porter, this->wfcpw->nmaxgr);
23+
if (this->npol != 1) {
24+
resize_memory_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr);
25+
}
26+
if (this->isk == nullptr || this->veff == nullptr || this->wfcpw == nullptr) {
2227
ModuleBase::WARNING_QUIT("VeffPW", "Constuctor of Operator::VeffPW is failed, please check your code!");
2328
}
2429
}
2530

31+
template<typename FPTYPE, typename Device>
32+
Veff<OperatorPW<FPTYPE, Device>>::~Veff()
33+
{
34+
delete_memory_op()(this->ctx, this->porter);
35+
if (this->npol != 1) {
36+
delete_memory_op()(this->ctx, this->porter1);
37+
}
38+
}
39+
2640
template<typename FPTYPE, typename Device>
2741
void Veff<OperatorPW<FPTYPE, Device>>::act(
2842
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
@@ -37,65 +51,64 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
3751
const int current_spin = this->isk[this->ik];
3852
this->npol = psi_in->npol;
3953

40-
std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
54+
// std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
4155
for (int ib = 0; ib < n_npwx; ib += this->npol)
4256
{
4357
if (this->npol == 1)
4458
{
45-
wfcpw->recip2real(tmpsi_in, porter, this->ik);
59+
// wfcpw->recip2real(tmpsi_in, porter, this->ik);
60+
wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
4661
// NOTICE: when MPI threads are larger than number of Z grids
4762
// veff would contain nothing, and nothing should be done in real space
4863
// but the 3DFFT can not be skipped, it will cause hanging
49-
if(this->veff->nc != 0)
64+
if(this->veff_col != 0)
5065
{
51-
const FPTYPE* current_veff = &(this->veff[0](current_spin, 0));
52-
for (int ir = 0; ir < this->veff->nc; ++ir)
53-
{
54-
porter[ir] *= current_veff[ir];
55-
}
66+
// const FPTYPE* current_veff = &(this->veff[0](current_spin, 0));
67+
// for (int ir = 0; ir < this->veff->nc; ++ir)
68+
// {
69+
// porter[ir] *= current_veff[ir];
70+
// }
71+
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
5672
}
57-
wfcpw->real2recip(porter, tmhpsi, this->ik, true);
73+
// wfcpw->real2recip(porter, tmhpsi, this->ik, true);
74+
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
5875
}
5976
else
6077
{
61-
std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
78+
// std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
6279
// fft to real space and doing things.
63-
wfcpw->recip2real(tmpsi_in, porter, this->ik);
64-
wfcpw->recip2real(tmpsi_in + this->max_npw, porter1, this->ik);
80+
wfcpw->recip2real(tmpsi_in, this->porter, this->ik);
81+
wfcpw->recip2real(tmpsi_in + this->max_npw, this->porter1, this->ik);
6582
std::complex<FPTYPE> sup, sdown;
66-
if(this->veff->nc != 0)
83+
if(this->veff_col != 0)
6784
{
6885
const FPTYPE* current_veff[4];
6986
for(int is=0;is<4;is++)
7087
{
71-
current_veff[is] = &(this->veff[0](is, 0));
88+
current_veff[is] = this->veff + is * this->veff_col;
7289
}
73-
for (int ir = 0; ir < this->veff->nc; ir++)
90+
for (int ir = 0; ir < this->veff_col; ir++)
7491
{
75-
sup = porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
76-
+ porter1[ir]
92+
sup = this->porter[ir] * (current_veff[0][ir] + current_veff[3][ir])
93+
+ this->porter1[ir]
7794
* (current_veff[1][ir]
7895
- std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
79-
sdown = porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
80-
+ porter[ir]
96+
sdown = this->porter1[ir] * (current_veff[0][ir] - current_veff[3][ir])
97+
+ this->porter[ir]
8198
* (current_veff[1][ir]
8299
+ std::complex<FPTYPE>(0.0, 1.0) * current_veff[2][ir]);
83-
porter[ir] = sup;
84-
porter1[ir] = sdown;
100+
this->porter[ir] = sup;
101+
this->porter1[ir] = sdown;
85102
}
86103
}
87104
// (3) fft back to G space.
88-
wfcpw->real2recip(porter, tmhpsi, this->ik, true);
89-
wfcpw->real2recip(porter1, tmhpsi + this->max_npw, this->ik, true);
90-
91-
delete[] porter1;
105+
wfcpw->real2recip(this->porter, tmhpsi, this->ik, true);
106+
wfcpw->real2recip(this->porter1, tmhpsi + this->max_npw, this->ik, true);
92107
}
93108
tmhpsi += this->max_npw * this->npol;
94109
tmpsi_in += this->max_npw * this->npol;
95110
}
96-
delete[] porter;
97111
ModuleBase::timer::tick("Operator", "VeffPW");
98-
return;
99112
}
100113

101114
namespace hamilt{

source/module_hamilt/ks_pw/veff_pw.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "operator_pw.h"
55
#include "module_base/matrix.h"
66
#include "module_pw/pw_basis_k.h"
7+
#include "module_hamilt/include/veff.h"
78

89
namespace hamilt {
910

@@ -23,7 +24,7 @@ class Veff<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
2324
public:
2425
Veff(const int* isk_in,const ModuleBase::matrix* veff_in,ModulePW::PW_Basis_K* wfcpw_in);
2526

26-
virtual ~Veff(){};
27+
virtual ~Veff();
2728

2829
virtual void act (
2930
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
@@ -40,9 +41,20 @@ class Veff<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>
4041

4142
const int* isk = nullptr;
4243

43-
const ModuleBase::matrix* veff = nullptr;
4444

4545
ModulePW::PW_Basis_K* wfcpw = nullptr;
46+
47+
Device* ctx = {};
48+
49+
int veff_col = 0;
50+
FPTYPE *veff = nullptr;
51+
std::complex<FPTYPE> *porter = nullptr;
52+
std::complex<FPTYPE> *porter1 = nullptr;
53+
54+
using veff_op = veff_pw_op<FPTYPE, Device>;
55+
56+
using resize_memory_op = psi::memory::resize_memory_op<std::complex<FPTYPE>, Device>;
57+
using delete_memory_op = psi::memory::delete_memory_op<std::complex<FPTYPE>, Device>;
4658
};
4759

4860
} // namespace hamilt
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "module_hamilt/include/veff.h"
2+
#include <complex>
3+
#include <thrust/complex.h>
4+
#include "cuda_runtime.h"
5+
6+
namespace hamilt{
7+
8+
#define THREADS_PER_BLOCK 256
9+
10+
/**
 * Element-wise in-place scaling kernel: out[i] *= in[i].
 *
 * Each thread handles the single element given by its global index;
 * threads whose index is at or beyond size do nothing.
 *
 * @param size  number of elements to process
 * @param out   complex array updated in place
 * @param in    real multipliers
 */
template <typename FPTYPE>
__global__ void veff_pw(
    const int size,
    thrust::complex<FPTYPE>* out,
    const FPTYPE* in)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < size) {
        out[tid] *= in[tid];
    }
}
20+
21+
/**
 * GPU implementation of veff_pw_op: launches the veff_pw kernel so that
 * out[i] *= in[i] for every i in [0, size).
 *
 * Host-side equivalent:
 *     for (int ir = 0; ir < size; ++ir) { out[ir] *= in[ir]; }
 *
 * @param dev   device tag; not used by this implementation
 * @param size  number of elements to process; non-positive values are a no-op
 * @param out   complex array updated in place; handed directly to the kernel,
 *              so presumably device-resident -- confirm against callers
 * @param in    real multipliers; same remark as for out
 */
template <typename FPTYPE>
void veff_pw_op<FPTYPE, psi::DEVICE_GPU>::operator() (
    const psi::DEVICE_GPU* dev,
    const int& size,
    std::complex<FPTYPE>* out,
    const FPTYPE* in)
{
    // A non-positive size would produce a grid with zero (or a wrapped
    // negative) number of blocks, which is not a valid launch configuration.
    // There is nothing to compute in that case, so return before launching.
    if (size <= 0) {
        return;
    }
    // Round up so the final, partially filled block is still launched.
    const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    veff_pw<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        size,                                             // control params
        reinterpret_cast<thrust::complex<FPTYPE>*>(out),  // array of data
        in);                                              // array of data
}
40+
41+
// Explicit instantiation: this translation unit emits the GPU functor for double precision only.
template struct veff_pw_op<double, psi::DEVICE_GPU>;
42+
43+
} // namespace hamilt

0 commit comments

Comments
 (0)