Skip to content

Commit f1bfe91

Browse files
authored
Fix: results not stable when using CUDA Aware MPI (#5976)
1 parent 648501d commit f1bfe91

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

source/module_base/para_gemm.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
256256
int m = colA_loc[ip];
257257
int size = m * LDA;
258258
MPI_Status status;
259+
#ifdef __CUDA_MPI
260+
// If the memory is not set to zero, it may cause the result to be wrong when using CUDA Aware MPI
261+
// I am not sure if it is due to CUDA Aware MPI or not
262+
base_device::memory::set_memory_op<T, Device>()(Atmp_device, 0, size);
263+
#endif
259264
Parallel_Common::recv_dev<T, Device>(Atmp_device, size, ip, 0, col_world, &status, A_tmp_.data());
260265
MPI_Wait(&requests[ip], &status);
261266
ModuleBase::gemm_op<T, Device>()('C',

source/module_hsolver/para_linear_transform.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
120120

121121
int size = LDA * ncolA_ip;
122122
MPI_Status status;
123+
#ifdef __CUDA_MPI
124+
// If the memory is not set to zero, it may cause the result to be wrong when using CUDA Aware MPI
125+
// I am not sure if it is due to CUDA Aware MPI or not
126+
base_device::memory::set_memory_op<T, Device>()(Atmp_device, 0, size);
127+
#endif
123128
Parallel_Common::recv_dev<T, Device>(Atmp_device, size, ip, 0, col_world, &status, A_tmp_.data());
124129
ModuleBase::gemm_op<T, Device>()('N',
125130
'N',

0 commit comments

Comments
 (0)