Skip to content

Commit bbd6ff0

Browse files
committed
prepare for PR
2 parents 64ae540 + cb53505 commit bbd6ff0

File tree

122 files changed

+2901
-1626
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+2901
-1626
lines changed

source/Makefile.Objects

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
261261
lcao_gets.o\
262262
lcao_others.o\
263263
lcao_init_after_vc.o\
264-
lcao_fun.o\
265264

266265
OBJS_GINT=gint.o\
267266
gint_gamma_env.o\
@@ -620,6 +619,7 @@ OBJS_PARALLEL=parallel_common.o\
620619
parallel_grid.o\
621620
parallel_kpoints.o\
622621
parallel_reduce.o\
622+
parallel_device.o
623623

624624
OBJS_SRCPW=H_Ewald_pw.o\
625625
dnrm2.o\
@@ -643,6 +643,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
643643
forces_cc.o\
644644
forces_scc.o\
645645
fs_nonlocal_tools.o\
646+
fs_kin_tools.o\
646647
force_op.o\
647648
stress_op.o\
648649
wf_op.o\

source/module_base/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_library(
5050
parallel_global.cpp
5151
parallel_comm.cpp
5252
parallel_reduce.cpp
53+
parallel_device.cpp
5354
spherical_bessel_transformer.cpp
5455
cubic_spline.cpp
5556
module_mixing/mixing_data.cpp

source/module_base/formatter.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ class FmtCore
145145
[&delim](const std::string& acc, const std::string& s) { return acc + delim + s; });
146146
}
147147

148+
/**
 * @brief Return a copy of the input with every character converted to upper case.
 *
 * The conversion is done per byte with the C library `toupper`, so it follows
 * the current C locale and does not understand multi-byte encodings.
 *
 * @param in the string to convert (left untouched)
 * @return std::string the upper-cased copy
 */
static std::string upper(const std::string& in)
{
    std::string dst = in;
    // toupper() is only defined for values representable as unsigned char (or EOF);
    // passing a negative plain char is undefined behaviour, hence the unsigned char parameter.
    std::transform(dst.begin(), dst.end(), dst.begin(), [](unsigned char c) {
        return static_cast<char>(::toupper(c));
    });
    return dst;
}
154+
155+
/**
 * @brief Return a copy of the input with every character converted to lower case.
 *
 * The conversion is done per byte with the C library `tolower`, so it follows
 * the current C locale and does not understand multi-byte encodings.
 *
 * @param in the string to convert (left untouched)
 * @return std::string the lower-cased copy
 */
static std::string lower(const std::string& in)
{
    std::string dst = in;
    // tolower() is only defined for values representable as unsigned char (or EOF);
    // passing a negative plain char is undefined behaviour, hence the unsigned char parameter.
    std::transform(dst.begin(), dst.end(), dst.begin(), [](unsigned char c) {
        return static_cast<char>(::tolower(c));
    });
    return dst;
}
161+
148162
private:
149163
std::string fmt_;
150164
template<typename T>

source/module_base/memory.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ int Memory::short_memory = sizeof(short); // 2.0 Byte
2323

2424
int Memory::n_memory = 1000;
2525
int Memory::n_now = 0;
26-
bool Memory::init_flag = false;
26+
bool Memory::init_flag = false;
2727

2828
#if defined(__CUDA) || defined(__ROCM)
2929

@@ -365,18 +365,15 @@ void Memory::finish(std::ofstream &ofs)
365365
delete[] name_gpu;
366366
delete[] class_name_gpu;
367367
delete[] consume_gpu;
368+
init_flag_gpu = false;
368369
}
369370
#endif
370371
return;
371372
}
372373

373374
void Memory::print_all(std::ofstream &ofs)
374375
{
375-
if(!init_flag
376-
#if defined(__CUDA) || defined(__ROCM)
377-
&& !init_flag_gpu
378-
#endif
379-
)
376+
if(!init_flag)
380377
{
381378
return;
382379
}
@@ -437,6 +434,11 @@ void Memory::print_all(std::ofstream &ofs)
437434
}
438435

439436
#if defined(__CUDA) || defined(__ROCM)
437+
if(!init_flag_gpu)
438+
{
439+
return;
440+
}
441+
440442
ofs <<"\n NAME-------------------------|GPU MEMORY(MB)----" << std::endl;
441443
ofs <<std::setw(30)<< "total" << std::setw(15) <<std::setprecision(4)<< Memory::total_gpu << std::endl;
442444

source/module_base/module_device/device.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,27 +191,30 @@ else { return "cpu";
191191
}
192192
}
193193

194-
/**
 * @brief Reconcile kpar with the MPI world size and bind this process to a GPU.
 *
 * In builds with MPI and CUDA/ROCm enabled:
 *   - if the number of MPI processes differs from kpar * bndpar, kpar is
 *     recomputed as (number of processes) / bndpar and a warning is issued;
 *   - the process is bound to device (node_rank % device_count) of its node.
 * In every other build the function returns kpar unchanged.
 *
 * @param kpar   requested kpar value
 * @param bndpar requested bndpar value
 * @return int the (possibly adjusted) kpar value
 */
int get_device_kpar(const int& kpar, const int& bndpar)
{
#if __MPI && (__CUDA || __ROCM)
    int temp_nproc = 0;
    int new_kpar = kpar;
    MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
    // bndpar must be positive for the division below to be meaningful;
    // a zero value would otherwise trigger an integer division by zero.
    if (bndpar > 0 && temp_nproc != kpar * bndpar)
    {
        new_kpar = temp_nproc / bndpar;
        ModuleBase::WARNING("Input_conv", "kpar is not compatible with the number of processors, auto set kpar value.");
    }

    // get the CPU rank of current node
    int node_rank = base_device::information::get_node_rank();

    int device_num = -1;
#if defined(__CUDA)
    cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
    if (device_num > 0)
    {
        cudaSetDevice(node_rank % device_num); // bind the CPU processor to the devices
    }
#elif defined(__ROCM)
    hipGetDeviceCount(&device_num); // get the number of GPU devices of current node
    if (device_num > 0)
    {
        hipSetDevice(node_rank % device_num); // bind the CPU processor to the devices
    }
#endif
    // With no usable device the modulo above would be a division by zero (count 0)
    // or silently select device 0 (count left at -1), so report it instead.
    if (device_num <= 0)
    {
        ModuleBase::WARNING("Input_conv", "no GPU device is available on this node, skip device binding.");
    }
    return new_kpar;
#endif
    return kpar;
}

source/module_base/module_device/device.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ std::string get_device_info(std::string device_flag);
4040
* @brief Get the device kpar object
4141
* for module_io GlobalV::KPAR
4242
*/
43-
int get_device_kpar(const int& kpar);
43+
int get_device_kpar(const int& kpar, const int& bndpar);
4444

4545
/**
4646
* @brief Get the device flag object
@@ -50,6 +50,12 @@ std::string get_device_flag(const std::string& device,
5050
const std::string& basis_type);
5151

5252
#if __MPI
53+
/**
54+
* @brief Get the rank of current node
55+
* Note that a GPU can only be bound to a CPU in the same node
56+
*
57+
* @return int
58+
*/
5359
int get_node_rank();
5460
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
5561
int stringCmp(const void* a, const void* b);

source/module_base/parallel_common.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@ void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 201
1111
{
1212
int size = object.size();
1313
MPI_Bcast(&size, 1, MPI_INT, 0, MPI_COMM_WORLD);
14-
char* swap = new char[size + 1];
14+
1515
int my_rank;
1616
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
17-
if (0 == my_rank)
18-
strcpy(swap, object.c_str());
19-
MPI_Bcast(swap, size + 1, MPI_CHAR, 0, MPI_COMM_WORLD);
17+
2018
if (0 != my_rank)
21-
object = static_cast<std::string>(swap);
22-
delete[] swap;
19+
{
20+
object.resize(size);
21+
}
22+
23+
MPI_Bcast(&object[0], size, MPI_CHAR, 0, MPI_COMM_WORLD);
2324
return;
2425
}
2526

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "parallel_device.h"
2+
#ifdef __MPI
3+
namespace Parallel_Common
4+
{
5+
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
6+
{
7+
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
8+
}
9+
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
10+
{
11+
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
12+
}
13+
void bcast_data(double* object, const int& n, const MPI_Comm& comm)
14+
{
15+
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
16+
}
17+
void bcast_data(float* object, const int& n, const MPI_Comm& comm)
18+
{
19+
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
20+
}
21+
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
22+
{
23+
MPI_Allreduce(MPI_IN_PLACE, object, n * 2, MPI_DOUBLE, MPI_SUM, comm);
24+
}
25+
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
26+
{
27+
MPI_Allreduce(MPI_IN_PLACE, object, n * 2, MPI_FLOAT, MPI_SUM, comm);
28+
}
29+
void reduce_data(double* object, const int& n, const MPI_Comm& comm)
30+
{
31+
MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_DOUBLE, MPI_SUM, comm);
32+
}
33+
void reduce_data(float* object, const int& n, const MPI_Comm& comm)
34+
{
35+
MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_FLOAT, MPI_SUM, comm);
36+
}
37+
}
38+
#endif

source/module_base/parallel_device.h

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,34 @@
1+
#ifndef __PARALLEL_DEVICE_H__
2+
#define __PARALLEL_DEVICE_H__
13
#ifdef __MPI
24
#include "mpi.h"
35
#include "module_base/module_device/device.h"
6+
#include "module_base/module_device/memory_op.h"
47
#include <complex>
5-
#include <string>
6-
#include <vector>
78
namespace Parallel_Common
89
{
9-
void bcast_complex(std::complex<double>* object, const int& n, const MPI_Comm& comm)
10-
{
11-
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
12-
}
13-
void bcast_complex(std::complex<float>* object, const int& n, const MPI_Comm& comm)
14-
{
15-
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
16-
}
17-
void bcast_real(double* object, const int& n, const MPI_Comm& comm)
18-
{
19-
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
20-
}
21-
void bcast_real(float* object, const int& n, const MPI_Comm& comm)
22-
{
23-
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
24-
}
10+
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
11+
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
12+
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
13+
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
14+
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
15+
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
16+
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
17+
void reduce_data(float* object, const int& n, const MPI_Comm& comm);
2518

26-
template <typename T, typename Device>
2719
/**
28-
* @brief bcast complex in Device
20+
* @brief bcast data in Device
2921
*
22+
* @tparam T: float, double, std::complex<float>, std::complex<double>
23+
* @tparam Device
3024
* @param ctx Device ctx
3125
* @param object data array in Device
3226
* @param n the number of elements in the array
3327
* @param comm MPI_Comm
3428
* @param tmp_space tmp space in CPU
3529
*/
36-
void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
30+
template <typename T, typename Device>
31+
void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
3732
{
3833
const base_device::DEVICE_CPU* cpu_ctx = {};
3934
T* object_cpu = nullptr;
@@ -56,7 +51,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
5651
object_cpu = object;
5752
}
5853

59-
bcast_complex(object_cpu, n, comm);
54+
bcast_data(object_cpu, n, comm);
6055

6156
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
6257
{
@@ -70,7 +65,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
7065
}
7166

7267
template <typename T, typename Device>
73-
void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
68+
void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
7469
{
7570
const base_device::DEVICE_CPU* cpu_ctx = {};
7671
T* object_cpu = nullptr;
@@ -93,7 +88,7 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
9388
object_cpu = object;
9489
}
9590

96-
bcast_real(object_cpu, n, comm);
91+
reduce_data(object_cpu, n, comm);
9792

9893
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
9994
{
@@ -105,7 +100,9 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
105100
}
106101
return;
107102
}
103+
108104
}
109105

110106

107+
#endif
111108
#endif

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
1414
ModuleBase::TITLE(this->classname, "psiToRho");
1515
ModuleBase::timer::tick(this->classname, "psiToRho");
1616
const int nspin = PARAM.inp.nspin;
17+
for (int is = 0; is < nspin; is++)
18+
{
19+
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
20+
}
1721

1822
if (GlobalV::MY_STOGROUP == 0)
1923
{
20-
for (int is = 0; is < nspin; is++)
21-
{
22-
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
23-
}
24-
2524
for (int ik = 0; ik < psi.get_nk(); ++ik)
2625
{
2726
psi.fix_k(ik);

0 commit comments

Comments
 (0)