Skip to content

Commit 39df3b9

Browse files
authored
fix a deadlock bug (#5210)
1 parent 66c4e58 commit 39df3b9

File tree

5 files changed

+8
-19
lines changed

5 files changed

+8
-19
lines changed

source/module_hamilt_lcao/module_gint/gint_force_gpu.cu

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,11 +31,7 @@ void gint_fvl_gpu(const hamilt::HContainer<double>* dm,
3131
const Grid_Technique& gridt,
3232
const UnitCell& ucell)
3333
{
34-
#ifdef __MPI
35-
const int dev_id = base_device::information::set_device_by_rank();
36-
#else
37-
const int dev_id = 0;
38-
#endif
34+
checkCuda(cudaSetDevice(gridt.dev_id));
3935
// checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
4036

4137
const int nbzp = gridt.nbzp;
@@ -99,7 +95,7 @@ void gint_fvl_gpu(const hamilt::HContainer<double>* dm,
9995
{
10096
// 20240620 Note that it must be set again here because
10197
// cuda's device is not safe in a multi-threaded environment.
102-
checkCuda(cudaSetDevice(dev_id));
98+
checkCuda(cudaSetDevice(gridt.dev_id));
10399
const int sid = omp_get_thread_num();
104100

105101
int max_m = 0;

source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,7 @@ void gint_rho_gpu(const hamilt::HContainer<double>* dm,
1717
const UnitCell& ucell,
1818
double* rho)
1919
{
20-
#ifdef __MPI
21-
const int dev_id = base_device::information::set_device_by_rank();
22-
#else
23-
const int dev_id = 0;
24-
#endif
20+
checkCuda(cudaSetDevice(gridt.dev_id));
2521
// checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
2622

2723
const int nbzp = gridt.nbzp;
@@ -80,7 +76,7 @@ void gint_rho_gpu(const hamilt::HContainer<double>* dm,
8076
// 20240620 Note that it must be set again here because
8177
// cuda's device is not safe in a multi-threaded environment.
8278

83-
checkCuda(cudaSetDevice(dev_id));
79+
checkCuda(cudaSetDevice(gridt.dev_id));
8480
// get stream id
8581
const int sid = omp_get_thread_num();
8682

source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,7 @@ void gint_vl_gpu(hamilt::HContainer<double>* hRGint,
2929
double* pvpR,
3030
const bool is_gamma_only)
3131
{
32-
#ifdef __MPI
33-
const int dev_id = base_device::information::set_device_by_rank();
34-
#else
35-
const int dev_id = 0;
36-
#endif
32+
checkCuda(cudaSetDevice(gridt.dev_id));
3733
// checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
3834
const int nbzp = gridt.nbzp;
3935
const int num_streams = gridt.nstreams;
@@ -81,7 +77,7 @@ void gint_vl_gpu(hamilt::HContainer<double>* hRGint,
8177
{
8278
// 20240620 Note that it must be set again here because
8379
// cuda's device is not safe in a multi-threaded environment.
84-
checkCuda(cudaSetDevice(dev_id));
80+
checkCuda(cudaSetDevice(gridt.dev_id));
8581
const int sid = omp_get_thread_num();
8682

8783
int max_m = 0;

source/module_hamilt_lcao/module_gint/grid_technique.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -562,7 +562,7 @@ int Grid_Technique::find_offset(const int id1, const int id2, const int iat1, co
562562
void Grid_Technique::init_gpu_gint_variables(const UnitCell& ucell,
563563
const int num_stream) {
564564
#ifdef __MPI
565-
base_device::information::set_device_by_rank();
565+
dev_id = base_device::information::set_device_by_rank();
566566
#endif
567567
if (is_malloced) {
568568
free_gpu_gint_variables(this->nat);

source/module_hamilt_lcao/module_gint/grid_technique.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,7 @@ class Grid_Technique : public Grid_MeshBall {
177177
double* rcut_g;
178178
double*mcell_pos_g;
179179

180+
int dev_id = 0;
180181
int nstreams = 4;
181182
// streams[nstreams]
182183
// TODO it needs to be implemented through configuration files

0 commit comments

Comments
 (0)