diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp
index abfacf5154..79fdbf0de7 100644
--- a/source/module_base/para_gemm.cpp
+++ b/source/module_base/para_gemm.cpp
@@ -256,6 +256,11 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con
        int m = colA_loc[ip];
        int size = m * LDA;
        MPI_Status status;
+#ifdef __CUDA_MPI
+        // Zero the device receive buffer first: stale device memory has been observed to
+        // produce wrong results with CUDA-aware MPI. Root cause unconfirmed -- TODO: investigate.
+        base_device::memory::set_memory_op()(Atmp_device, 0, size);
+#endif
        Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp_.data());
        MPI_Wait(&requests[ip], &status);
        ModuleBase::gemm_op()('C',
diff --git a/source/module_hsolver/para_linear_transform.cpp b/source/module_hsolver/para_linear_transform.cpp
index 5a6c8def27..d611120346 100644
--- a/source/module_hsolver/para_linear_transform.cpp
+++ b/source/module_hsolver/para_linear_transform.cpp
@@ -120,6 +120,11 @@ void PLinearTransform::act(const T alpha, const T* A, const T* U, con
            int size = LDA * ncolA_ip;

            MPI_Status status;
+#ifdef __CUDA_MPI
+            // Zero the device receive buffer first: stale device memory has been observed to
+            // produce wrong results with CUDA-aware MPI. Root cause unconfirmed -- TODO: investigate.
+            base_device::memory::set_memory_op()(Atmp_device, 0, size);
+#endif
            Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp_.data());
            ModuleBase::gemm_op()('N',
                                  'N',