diff --git a/source/module_hamilt_lcao/module_gint/gint_force_gpu.cu b/source/module_hamilt_lcao/module_gint/gint_force_gpu.cu index 9ab974660f..ac7f9e89c8 100644 --- a/source/module_hamilt_lcao/module_gint/gint_force_gpu.cu +++ b/source/module_hamilt_lcao/module_gint/gint_force_gpu.cu @@ -31,11 +31,7 @@ void gint_fvl_gpu(const hamilt::HContainer* dm, const Grid_Technique& gridt, const UnitCell& ucell) { -#ifdef __MPI - const int dev_id = base_device::information::set_device_by_rank(); -#else - const int dev_id = 0; -#endif + checkCuda(cudaSetDevice(gridt.dev_id)); // checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync)); const int nbzp = gridt.nbzp; @@ -99,7 +95,7 @@ void gint_fvl_gpu(const hamilt::HContainer* dm, { // 20240620 Note that it must be set again here because // cuda's device is not safe in a multi-threaded environment. - checkCuda(cudaSetDevice(dev_id)); + checkCuda(cudaSetDevice(gridt.dev_id)); const int sid = omp_get_thread_num(); int max_m = 0; diff --git a/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu b/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu index 2093105c58..4b18d50438 100644 --- a/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu +++ b/source/module_hamilt_lcao/module_gint/gint_rho_gpu.cu @@ -17,11 +17,7 @@ void gint_rho_gpu(const hamilt::HContainer* dm, const UnitCell& ucell, double* rho) { -#ifdef __MPI - const int dev_id = base_device::information::set_device_by_rank(); -#else - const int dev_id = 0; -#endif + checkCuda(cudaSetDevice(gridt.dev_id)); // checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync)); const int nbzp = gridt.nbzp; @@ -80,7 +76,7 @@ void gint_rho_gpu(const hamilt::HContainer* dm, // 20240620 Note that it must be set again here because // cuda's device is not safe in a multi-threaded environment. - checkCuda(cudaSetDevice(dev_id)); + checkCuda(cudaSetDevice(gridt.dev_id)); // get stream id const int sid = omp_get_thread_num(); diff --git a/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu b/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu index 40d06dd186..a406fd24fe 100644 --- a/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu +++ b/source/module_hamilt_lcao/module_gint/gint_vl_gpu.cu @@ -29,11 +29,7 @@ void gint_vl_gpu(hamilt::HContainer* hRGint, double* pvpR, const bool is_gamma_only) { -#ifdef __MPI - const int dev_id = base_device::information::set_device_by_rank(); -#else - const int dev_id = 0; -#endif + checkCuda(cudaSetDevice(gridt.dev_id)); // checkCuda(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync)); const int nbzp = gridt.nbzp; const int num_streams = gridt.nstreams; @@ -81,7 +77,7 @@ void gint_vl_gpu(hamilt::HContainer* hRGint, { // 20240620 Note that it must be set again here because // cuda's device is not safe in a multi-threaded environment. - checkCuda(cudaSetDevice(dev_id)); + checkCuda(cudaSetDevice(gridt.dev_id)); const int sid = omp_get_thread_num(); int max_m = 0; diff --git a/source/module_hamilt_lcao/module_gint/grid_technique.cpp b/source/module_hamilt_lcao/module_gint/grid_technique.cpp index c66c47a24f..42d441e3ec 100644 --- a/source/module_hamilt_lcao/module_gint/grid_technique.cpp +++ b/source/module_hamilt_lcao/module_gint/grid_technique.cpp @@ -562,7 +562,7 @@ int Grid_Technique::find_offset(const int id1, const int id2, const int iat1, co void Grid_Technique::init_gpu_gint_variables(const UnitCell& ucell, const int num_stream) { #ifdef __MPI - base_device::information::set_device_by_rank(); + dev_id = base_device::information::set_device_by_rank(); #endif if (is_malloced) { free_gpu_gint_variables(this->nat); diff --git a/source/module_hamilt_lcao/module_gint/grid_technique.h b/source/module_hamilt_lcao/module_gint/grid_technique.h index ec37922765..cffcc99aaa 100644 --- a/source/module_hamilt_lcao/module_gint/grid_technique.h +++ b/source/module_hamilt_lcao/module_gint/grid_technique.h @@ -177,6 +177,7 @@ class Grid_Technique : public Grid_MeshBall { double* rcut_g; double*mcell_pos_g; + int dev_id = 0; int nstreams = 4; // streams[nstreams] // TODO it needs to be implemented through configuration files