@@ -24,55 +24,59 @@ void PGemmCN<T, Device>::set_dimension(
2424 const int ncolB_in,
2525 const int LDB_in,
2626 const int nrow_in,
27- const int LDC_global_in)
27+ const int LDC_in,
28+ const bool gatherC_in)
2829{
2930#ifdef __MPI
3031 MPI_Comm_rank (comm_col, &col_rank);
3132 MPI_Comm_size (comm_col, &col_nproc);
3233 if (comm_row != MPI_COMM_NULL)
3334 {
3435 MPI_Comm_rank (comm_row, &row_rank);
36+ MPI_Comm_size (comm_row, &row_nproc);
3537 }
3638 col_world = comm_col;
3739 row_world = comm_row;
3840#endif
3941 this ->LDA = LDA_in;
4042 this ->LDB = LDB_in;
41- this ->LDC_global = LDC_global_in ;
43+ this ->LDC = LDC_in ;
4244 this ->ncolA = ncolA_in;
4345 this ->ncolB = ncolB_in;
4446 this ->nrow = nrow_in;
4547#ifdef __MPI
48+ this ->gatherC = gatherC_in;
4649 colA_loc.resize (col_nproc);
47- colB_loc.resize (col_nproc);
48- row_loc.resize (col_nproc);
49- recv_counts.resize (col_nproc);
50- displs.resize (col_nproc);
51- requests.resize (col_nproc);
5250 MPI_Allgather (&ncolA, 1 , MPI_INT, colA_loc.data (), 1 , MPI_INT, col_world);
53- MPI_Allgather (&ncolB, 1 , MPI_INT, colB_loc.data (), 1 , MPI_INT, col_world);
54- MPI_Allgather (&nrow, 1 , MPI_INT, row_loc.data (), 1 , MPI_INT, col_world);
5551 for (int ip = 0 ; ip < col_nproc; ip++)
5652 {
5753 max_colA = std::max (max_colA, colA_loc[ip]);
5854 }
5955
60- for (int ip = 0 ; ip < col_nproc; ip++)
61- {
62- recv_counts[ip] = LDC_global * colB_loc[ip];
63- }
64- displs[0 ] = 0 ;
65- for (int ip = 1 ; ip < col_nproc; ip++)
56+ if (this ->gatherC )
6657 {
67- displs[ip] = displs[ip - 1 ] + recv_counts[ip - 1 ];
58+ colB_loc.resize (col_nproc);
59+ recv_counts.resize (col_nproc);
60+ displs.resize (col_nproc);
61+ requests.resize (col_nproc);
62+ MPI_Allgather (&ncolB, 1 , MPI_INT, colB_loc.data (), 1 , MPI_INT, col_world);
63+ for (int ip = 0 ; ip < col_nproc; ip++)
64+ {
65+ recv_counts[ip] = LDC * colB_loc[ip];
66+ }
67+ displs[0 ] = 0 ;
68+ for (int ip = 1 ; ip < col_nproc; ip++)
69+ {
70+ displs[ip] = displs[ip - 1 ] + recv_counts[ip - 1 ];
71+ }
72+ size_C_global = displs[col_nproc - 1 ] + recv_counts[col_nproc - 1 ];
6873 }
69- size_C_global = displs[col_nproc - 1 ] + recv_counts[col_nproc - 1 ];
70- send_counts = ncolB * LDC_global;
74+ size_C_local = ncolB * LDC;
7175#endif
7276}
7377
7478template <typename T, typename Device>
75- void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C_global )
79+ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C )
7680{
7781 const Device* ctx = {};
7882#ifdef __MPI
@@ -88,20 +92,23 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
8892 }
8993 }
9094
91- std::vector<T> C_tmp (send_counts);
92- T* Ctmp_device = nullptr ;
93- if (std::is_same<Device, base_device::DEVICE_GPU>::value)
94- {
95- resmem_dev_op ()(Ctmp_device, send_counts);
96- }
97- else
95+ T* C_local = C;
96+ std::vector<T> C_tmp;
97+ if (this ->gatherC )
9898 {
99- Ctmp_device = C_tmp.data ();
99+ C_tmp.resize (size_C_local);
100+ if (std::is_same<Device, base_device::DEVICE_GPU>::value)
101+ {
102+ C_local = nullptr ;
103+ resmem_dev_op ()(C_local, size_C_local);
104+ }
105+ else
106+ {
107+ C_local = C_tmp.data ();
108+ }
109+ syncmem_dev_op ()(C_local, C + displs[col_rank], size_C_local);
100110 }
101111
102- T* C_local = C_global + displs[col_rank];
103- syncmem_dev_op ()(Ctmp_device, C_local, send_counts);
104-
105112 T* Atmp_device = nullptr ;
106113 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
107114 {
@@ -116,7 +123,7 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
116123 T real_beta = row_rank == 0 ? beta : 0 ;
117124 for (int ip = 0 ; ip < col_nproc; ip++)
118125 {
119- T* C_start = Ctmp_device + shift;
126+ T* C_start = C_local + shift;
120127 if (col_rank == ip)
121128 {
122129 ModuleBase::gemm_op<T, Device>()(ctx,
@@ -132,7 +139,7 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
132139 LDB,
133140 &real_beta,
134141 C_start,
135- LDC_global );
142+ LDC );
136143 shift += ncolA;
137144 }
138145 else
@@ -155,61 +162,65 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
155162 LDB,
156163 &real_beta,
157164 C_start,
158- LDC_global );
165+ LDC );
159166 shift += m;
160167 }
161168 }
162169
163- T* Cglobal_cpu = nullptr ;
164- if (std::is_same<Device, base_device::DEVICE_GPU>::value)
170+ if (this ->gatherC )
165171 {
166- delmem_dev_op ()(Ctmp_device);
167- delmem_dev_op ()(Atmp_device);
168- syncmem_dev_op ()(C_tmp.data (), Ctmp_device, send_counts);
169- resmem_dev_op ()(Cglobal_cpu, size_C_global);
172+ T* Cglobal_cpu = nullptr ;
173+ T* Clocal_cpu = C_tmp.data ();;
174+ if (std::is_same<Device, base_device::DEVICE_GPU>::value)
175+ {
176+ delmem_dev_op ()(Atmp_device);
177+
178+ syncmem_d2h_op ()(Clocal_cpu, C_local, size_C_local);
179+ delmem_dev_op ()(C_local);
180+
181+ resmem_dev_op ()(Cglobal_cpu, size_C_global);
182+ }
183+ else
184+ {
185+ Cglobal_cpu = C;
186+ }
187+ if (this ->row_nproc > 1 )
188+ {
189+ Parallel_Common::reduce_data (Clocal_cpu, size_C_local, row_world);
190+ }
191+ Parallel_Common::gatherv_data (Clocal_cpu,
192+ size_C_local,
193+ Cglobal_cpu,
194+ recv_counts.data (),
195+ displs.data (),
196+ col_world);
197+
198+ if (std::is_same<Device, base_device::DEVICE_GPU>::value)
199+ {
200+ syncmem_h2d_op ()(C, Cglobal_cpu, size_C_global);
201+ delmem_dev_op ()(Cglobal_cpu);
202+ }
170203 }
171204 else
172205 {
173- Cglobal_cpu = C_global;
174- }
175- Parallel_Common::gatherv_data (C_tmp.data (),
176- send_counts,
177- Cglobal_cpu,
178- recv_counts.data (),
179- displs.data (),
180- col_world);
181- if (row_world != MPI_COMM_NULL)
182- {
183- Parallel_Common::reduce_data (Cglobal_cpu, size_C_global, row_world);
184- }
185- if (std::is_same<Device, base_device::DEVICE_GPU>::value)
186- {
187- syncmem_dev_op ()(C_global, Cglobal_cpu, size_C_global);
188- delmem_dev_op ()(Cglobal_cpu);
206+ if (this ->row_nproc > 1 )
207+ {
208+ Parallel_Common::reduce_dev<T, Device>(C, size_C_local, row_world);
209+ }
189210 }
190211 }
191212 else
192213 {
193214 T real_beta = row_rank == 0 ? beta : 0 ;
194215#else
195- T real_beta = beta;
216+ T real_beta = beta;
196217#endif
197- ModuleBase::gemm_op<T, Device>()(ctx,
198- ' C' ,
199- ' N' ,
200- ncolA,
201- ncolB,
202- nrow,
203- &alpha,
204- A,
205- LDA,
206- B,
207- LDB,
208- &real_beta,
209- C_global,
210- LDC_global);
218+ ModuleBase::gemm_op<T, Device>()(ctx, ' C' , ' N' , ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC);
211219#ifdef __MPI
212- Parallel_Common::reduce_dev<T, Device>(C_global, size_C_global, row_world);
220+ if (this ->row_nproc > 1 )
221+ {
222+ Parallel_Common::reduce_dev<T, Device>(C, size_C_local, row_world);
223+ }
213224 }
214225#endif
215226}
0 commit comments