1+ #ifndef __PARALLEL_DEVICE_H__
2+ #define __PARALLEL_DEVICE_H__
13#ifdef __MPI
24#include " mpi.h"
35#include " module_base/module_device/device.h"
6+ #include " module_base/module_device/memory_op.h"
47#include < complex>
5- #include < string>
6- #include < vector>
78namespace Parallel_Common
89{
9- void bcast_complex (std::complex <double >* object, const int & n, const MPI_Comm& comm)
10- {
11- MPI_Bcast (object, n * 2 , MPI_DOUBLE, 0 , comm);
12- }
13- void bcast_complex (std::complex <float >* object, const int & n, const MPI_Comm& comm)
14- {
15- MPI_Bcast (object, n * 2 , MPI_FLOAT, 0 , comm);
16- }
17- void bcast_real (double * object, const int & n, const MPI_Comm& comm)
18- {
19- MPI_Bcast (object, n, MPI_DOUBLE, 0 , comm);
20- }
21- void bcast_real (float * object, const int & n, const MPI_Comm& comm)
22- {
23- MPI_Bcast (object, n, MPI_FLOAT, 0 , comm);
24- }
10+ void bcast_data (std::complex <double >* object, const int & n, const MPI_Comm& comm);
11+ void bcast_data (std::complex <float >* object, const int & n, const MPI_Comm& comm);
12+ void bcast_data (double * object, const int & n, const MPI_Comm& comm);
13+ void bcast_data (float * object, const int & n, const MPI_Comm& comm);
14+ void reduce_data (std::complex <double >* object, const int & n, const MPI_Comm& comm);
15+ void reduce_data (std::complex <float >* object, const int & n, const MPI_Comm& comm);
16+ void reduce_data (double * object, const int & n, const MPI_Comm& comm);
17+ void reduce_data (float * object, const int & n, const MPI_Comm& comm);
2518
26- template <typename T, typename Device>
2719/* *
28- * @brief bcast complex in Device
20+ * @brief bcast data in Device
2921 *
22+ * @tparam T: float, double, std::complex<float>, std::complex<double>
23+ * @tparam Device
3024 * @param ctx Device ctx
3125 * @param object complex arrays in Device
3226 * @param n the size of complex arrays
3327 * @param comm MPI_Comm
3428 * @param tmp_space tmp space in CPU
3529 */
36- void bcast_complex (const Device* ctx, T* object, const int & n, const MPI_Comm& comm, T* tmp_space = nullptr )
30+ template <typename T, typename Device>
31+ void bcast_dev (const Device* ctx, T* object, const int & n, const MPI_Comm& comm, T* tmp_space = nullptr )
3732{
3833 const base_device::DEVICE_CPU* cpu_ctx = {};
3934 T* object_cpu = nullptr ;
@@ -56,7 +51,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
5651 object_cpu = object;
5752 }
5853
59- bcast_complex (object_cpu, n, comm);
54+ bcast_data (object_cpu, n, comm);
6055
6156 if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
6257 {
@@ -70,7 +65,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
7065}
7166
7267template <typename T, typename Device>
73- void bcast_real (const Device* ctx, T* object, const int & n, const MPI_Comm& comm, T* tmp_space = nullptr )
68+ void reduce_dev (const Device* ctx, T* object, const int & n, const MPI_Comm& comm, T* tmp_space = nullptr )
7469{
7570 const base_device::DEVICE_CPU* cpu_ctx = {};
7671 T* object_cpu = nullptr ;
@@ -93,7 +88,7 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
9388 object_cpu = object;
9489 }
9590
96- bcast_real (object_cpu, n, comm);
91+ reduce_data (object_cpu, n, comm);
9792
9893 if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
9994 {
@@ -105,7 +100,9 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
105100 }
106101 return ;
107102}
103+
108104}
109105
110106
107+ #endif
111108#endif
0 commit comments