Skip to content

Commit f2e91bd

Browse files
authored
Feature: make force and stress of sDFT support GPU (#5487)
* refactor force in sdft * refactor stress in sDFT * make stress_ekin GPU * finish sdft GPU * fix compile * add annotations * fix bug of stress and force * modify
1 parent 5b1777c commit f2e91bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1645
-1073
lines changed

source/Makefile.Objects

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ OBJS_PARALLEL=parallel_common.o\
617617
parallel_grid.o\
618618
parallel_kpoints.o\
619619
parallel_reduce.o\
620+
parallel_device.o
620621

621622
OBJS_SRCPW=H_Ewald_pw.o\
622623
dnrm2.o\
@@ -640,6 +641,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
640641
forces_cc.o\
641642
forces_scc.o\
642643
fs_nonlocal_tools.o\
644+
fs_kin_tools.o\
643645
force_op.o\
644646
stress_op.o\
645647
wf_op.o\

source/module_base/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_library(
5050
parallel_global.cpp
5151
parallel_comm.cpp
5252
parallel_reduce.cpp
53+
parallel_device.cpp
5354
spherical_bessel_transformer.cpp
5455
cubic_spline.cpp
5556
module_mixing/mixing_data.cpp

source/module_base/module_device/device.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,27 +191,30 @@ else { return "cpu";
191191
}
192192
}
193193

194-
int get_device_kpar(const int &kpar) {
194+
int get_device_kpar(const int& kpar, const int& bndpar)
195+
{
195196
#if __MPI && (__CUDA || __ROCM)
196-
int temp_nproc;
197-
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
198-
if (temp_nproc != kpar) {
199-
ModuleBase::WARNING("Input_conv",
200-
"None kpar set in INPUT file, auto set kpar value.");
201-
}
202-
// GlobalV::KPAR = temp_nproc;
203-
// band the CPU processor to the devices
204-
int node_rank = base_device::information::get_node_rank();
197+
int temp_nproc = 0;
198+
int new_kpar = kpar;
199+
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
200+
if (temp_nproc != kpar * bndpar)
201+
{
202+
new_kpar = temp_nproc / bndpar;
203+
ModuleBase::WARNING("Input_conv", "kpar is not compatible with the number of processors, auto set kpar value.");
204+
}
205+
206+
// get the CPU rank of current node
207+
int node_rank = base_device::information::get_node_rank();
205208

206-
int device_num = -1;
209+
int device_num = -1;
207210
#if defined(__CUDA)
208-
cudaGetDeviceCount(&device_num);
209-
cudaSetDevice(node_rank % device_num);
211+
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
212+
cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
210213
#elif defined(__ROCM)
211214
hipGetDeviceCount(&device_num);
212215
hipSetDevice(node_rank % device_num);
213216
#endif
214-
return temp_nproc;
217+
return new_kpar;
215218
#endif
216219
return kpar;
217220
}

source/module_base/module_device/device.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ std::string get_device_info(std::string device_flag);
4040
* @brief Get the device kpar object
4141
* for module_io GlobalV::KPAR
4242
*/
43-
int get_device_kpar(const int& kpar);
43+
int get_device_kpar(const int& kpar, const int& bndpar);
4444

4545
/**
4646
* @brief Get the device flag object
@@ -50,6 +50,12 @@ std::string get_device_flag(const std::string& device,
5050
const std::string& basis_type);
5151

5252
#if __MPI
53+
/**
54+
* @brief Get the rank of current node
55+
* Note that GPU can only be binded with CPU in the same node
56+
*
57+
* @return int
58+
*/
5359
int get_node_rank();
5460
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
5561
int stringCmp(const void* a, const void* b);

source/module_base/parallel_common.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@ void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 201
1111
{
1212
int size = object.size();
1313
MPI_Bcast(&size, 1, MPI_INT, 0, MPI_COMM_WORLD);
14-
char* swap = new char[size + 1];
14+
1515
int my_rank;
1616
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
17-
if (0 == my_rank)
18-
strcpy(swap, object.c_str());
19-
MPI_Bcast(swap, size + 1, MPI_CHAR, 0, MPI_COMM_WORLD);
17+
2018
if (0 != my_rank)
21-
object = static_cast<std::string>(swap);
22-
delete[] swap;
19+
{
20+
object.resize(size);
21+
}
22+
23+
MPI_Bcast(&object[0], size, MPI_CHAR, 0, MPI_COMM_WORLD);
2324
return;
2425
}
2526

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "parallel_device.h"
2+
#ifdef __MPI
3+
namespace Parallel_Common
4+
{
5+
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
6+
{
7+
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
8+
}
9+
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
10+
{
11+
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
12+
}
13+
void bcast_data(double* object, const int& n, const MPI_Comm& comm)
14+
{
15+
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
16+
}
17+
void bcast_data(float* object, const int& n, const MPI_Comm& comm)
18+
{
19+
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
20+
}
21+
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
22+
{
23+
MPI_Allreduce(MPI_IN_PLACE, object, n * 2, MPI_DOUBLE, MPI_SUM, comm);
24+
}
25+
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
26+
{
27+
MPI_Allreduce(MPI_IN_PLACE, object, n * 2, MPI_FLOAT, MPI_SUM, comm);
28+
}
29+
void reduce_data(double* object, const int& n, const MPI_Comm& comm)
30+
{
31+
MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_DOUBLE, MPI_SUM, comm);
32+
}
33+
void reduce_data(float* object, const int& n, const MPI_Comm& comm)
34+
{
35+
MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_FLOAT, MPI_SUM, comm);
36+
}
37+
}
38+
#endif

source/module_base/parallel_device.h

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,34 @@
1+
#ifndef __PARALLEL_DEVICE_H__
2+
#define __PARALLEL_DEVICE_H__
13
#ifdef __MPI
24
#include "mpi.h"
35
#include "module_base/module_device/device.h"
6+
#include "module_base/module_device/memory_op.h"
47
#include <complex>
5-
#include <string>
6-
#include <vector>
78
namespace Parallel_Common
89
{
9-
void bcast_complex(std::complex<double>* object, const int& n, const MPI_Comm& comm)
10-
{
11-
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
12-
}
13-
void bcast_complex(std::complex<float>* object, const int& n, const MPI_Comm& comm)
14-
{
15-
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
16-
}
17-
void bcast_real(double* object, const int& n, const MPI_Comm& comm)
18-
{
19-
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
20-
}
21-
void bcast_real(float* object, const int& n, const MPI_Comm& comm)
22-
{
23-
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
24-
}
10+
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
11+
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
12+
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
13+
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
14+
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
15+
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
16+
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
17+
void reduce_data(float* object, const int& n, const MPI_Comm& comm);
2518

26-
template <typename T, typename Device>
2719
/**
28-
* @brief bcast complex in Device
20+
* @brief bcast data in Device
2921
*
22+
* @tparam T: float, double, std::complex<float>, std::complex<double>
23+
* @tparam Device
3024
* @param ctx Device ctx
3125
* @param object complex arrays in Device
3226
* @param n the size of complex arrays
3327
* @param comm MPI_Comm
3428
* @param tmp_space tmp space in CPU
3529
*/
36-
void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
30+
template <typename T, typename Device>
31+
void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
3732
{
3833
const base_device::DEVICE_CPU* cpu_ctx = {};
3934
T* object_cpu = nullptr;
@@ -56,7 +51,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
5651
object_cpu = object;
5752
}
5853

59-
bcast_complex(object_cpu, n, comm);
54+
bcast_data(object_cpu, n, comm);
6055

6156
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
6257
{
@@ -70,7 +65,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
7065
}
7166

7267
template <typename T, typename Device>
73-
void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
68+
void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
7469
{
7570
const base_device::DEVICE_CPU* cpu_ctx = {};
7671
T* object_cpu = nullptr;
@@ -93,7 +88,7 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
9388
object_cpu = object;
9489
}
9590

96-
bcast_real(object_cpu, n, comm);
91+
reduce_data(object_cpu, n, comm);
9792

9893
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
9994
{
@@ -105,7 +100,9 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
105100
}
106101
return;
107102
}
103+
108104
}
109105

110106

107+
#endif
111108
#endif

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
1414
ModuleBase::TITLE(this->classname, "psiToRho");
1515
ModuleBase::timer::tick(this->classname, "psiToRho");
1616
const int nspin = PARAM.inp.nspin;
17+
for (int is = 0; is < nspin; is++)
18+
{
19+
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
20+
}
1721

1822
if (GlobalV::MY_STOGROUP == 0)
1923
{
20-
for (int is = 0; is < nspin; is++)
21-
{
22-
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
23-
}
24-
2524
for (int ik = 0; ik < psi.get_nk(); ++ik)
2625
{
2726
psi.fix_k(ik);

source/module_elecstate/module_charge/charge_extra.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,20 @@ void Charge_Extra::Init_CE(const int& nspin, const int& natom, const int& nrxx,
4747

4848
if (pot_order > 0)
4949
{
50-
delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
51-
delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
52-
delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
50+
// delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
51+
// delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
52+
// delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
53+
// qianrui replace the above code with the following code.
54+
// The above code cannot passed valgrind tests, which has an invalid read of size 32.
55+
delta_rho1.resize(this->nspin);
56+
delta_rho2.resize(this->nspin);
57+
delta_rho3.resize(this->nspin);
58+
for (int is = 0; is < this->nspin; is++)
59+
{
60+
delta_rho1[is].resize(nrxx, 0.0);
61+
delta_rho2[is].resize(nrxx, 0.0);
62+
delta_rho3[is].resize(nrxx, 0.0);
63+
}
5364
}
5465

5566
if(pot_order == 3)

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,6 @@ void ESolver_KS_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
623623
&this->sf,
624624
&this->kv,
625625
this->pw_wfc,
626-
this->psi,
627626
this->__kspw_psi);
628627

629628
// external stress

0 commit comments

Comments
 (0)