
Commit bc7b92c

add gatherC for para_gemm

1 parent 1f26646

File tree

3 files changed: +154 −108 lines changed

source/module_base/para_gemm.cpp

Lines changed: 82 additions & 71 deletions
@@ -24,55 +24,59 @@ void PGemmCN<T, Device>::set_dimension(
     const int ncolB_in,
     const int LDB_in,
     const int nrow_in,
-    const int LDC_global_in)
+    const int LDC_in,
+    const bool gatherC_in)
 {
 #ifdef __MPI
     MPI_Comm_rank(comm_col, &col_rank);
     MPI_Comm_size(comm_col, &col_nproc);
     if (comm_row != MPI_COMM_NULL)
     {
         MPI_Comm_rank(comm_row, &row_rank);
+        MPI_Comm_size(comm_row, &row_nproc);
     }
     col_world = comm_col;
     row_world = comm_row;
 #endif
     this->LDA = LDA_in;
     this->LDB = LDB_in;
-    this->LDC_global = LDC_global_in;
+    this->LDC = LDC_in;
     this->ncolA = ncolA_in;
     this->ncolB = ncolB_in;
     this->nrow = nrow_in;
 #ifdef __MPI
+    this->gatherC = gatherC_in;
     colA_loc.resize(col_nproc);
-    colB_loc.resize(col_nproc);
-    row_loc.resize(col_nproc);
-    recv_counts.resize(col_nproc);
-    displs.resize(col_nproc);
-    requests.resize(col_nproc);
     MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world);
-    MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world);
-    MPI_Allgather(&nrow, 1, MPI_INT, row_loc.data(), 1, MPI_INT, col_world);
     for (int ip = 0; ip < col_nproc; ip++)
     {
         max_colA = std::max(max_colA, colA_loc[ip]);
     }

-    for (int ip = 0; ip < col_nproc; ip++)
-    {
-        recv_counts[ip] = LDC_global * colB_loc[ip];
-    }
-    displs[0] = 0;
-    for (int ip = 1; ip < col_nproc; ip++)
+    if (this->gatherC)
     {
-        displs[ip] = displs[ip - 1] + recv_counts[ip - 1];
+        colB_loc.resize(col_nproc);
+        recv_counts.resize(col_nproc);
+        displs.resize(col_nproc);
+        requests.resize(col_nproc);
+        MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world);
+        for (int ip = 0; ip < col_nproc; ip++)
+        {
+            recv_counts[ip] = LDC * colB_loc[ip];
+        }
+        displs[0] = 0;
+        for (int ip = 1; ip < col_nproc; ip++)
+        {
+            displs[ip] = displs[ip - 1] + recv_counts[ip - 1];
+        }
+        size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1];
     }
-    size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1];
-    send_counts = ncolB * LDC_global;
+    size_C_local = ncolB * LDC;
 #endif
 }

 template <typename T, typename Device>
-void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C_global)
+void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C)
 {
     const Device* ctx = {};
 #ifdef __MPI
@@ -88,20 +92,23 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
             }
         }

-        std::vector<T> C_tmp(send_counts);
-        T* Ctmp_device = nullptr;
-        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
-        {
-            resmem_dev_op()(Ctmp_device, send_counts);
-        }
-        else
+        T* C_local = C;
+        std::vector<T> C_tmp;
+        if (this->gatherC)
         {
-            Ctmp_device = C_tmp.data();
+            C_tmp.resize(size_C_local);
+            if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+            {
+                C_local = nullptr;
+                resmem_dev_op()(C_local, size_C_local);
+            }
+            else
+            {
+                C_local = C_tmp.data();
+            }
+            syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local);
         }

-        T* C_local = C_global + displs[col_rank];
-        syncmem_dev_op()(Ctmp_device, C_local, send_counts);
-
         T* Atmp_device = nullptr;
         if (std::is_same<Device, base_device::DEVICE_GPU>::value)
         {
@@ -116,7 +123,7 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
         T real_beta = row_rank == 0 ? beta : 0;
         for (int ip = 0; ip < col_nproc; ip++)
         {
-            T* C_start = Ctmp_device + shift;
+            T* C_start = C_local + shift;
             if (col_rank == ip)
             {
                 ModuleBase::gemm_op<T, Device>()(ctx,
@@ -132,7 +139,7 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
                                                  LDB,
                                                  &real_beta,
                                                  C_start,
-                                                 LDC_global);
+                                                 LDC);
                 shift += ncolA;
             }
             else
@@ -155,61 +162,65 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
                                                  LDB,
                                                  &real_beta,
                                                  C_start,
-                                                 LDC_global);
+                                                 LDC);
                 shift += m;
             }
         }

-        T* Cglobal_cpu = nullptr;
-        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+        if (this->gatherC)
         {
-            delmem_dev_op()(Ctmp_device);
-            delmem_dev_op()(Atmp_device);
-            syncmem_dev_op()(C_tmp.data(), Ctmp_device, send_counts);
-            resmem_dev_op()(Cglobal_cpu, size_C_global);
+            T* Cglobal_cpu = nullptr;
+            T* Clocal_cpu = C_tmp.data();;
+            if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+            {
+                delmem_dev_op()(Atmp_device);
+
+                syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local);
+                delmem_dev_op()(C_local);
+
+                resmem_dev_op()(Cglobal_cpu, size_C_global);
+            }
+            else
+            {
+                Cglobal_cpu = C;
+            }
+            if (this->row_nproc > 1)
+            {
+                Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world);
+            }
+            Parallel_Common::gatherv_data(Clocal_cpu,
+                                          size_C_local,
+                                          Cglobal_cpu,
+                                          recv_counts.data(),
+                                          displs.data(),
+                                          col_world);
+
+            if (std::is_same<Device, base_device::DEVICE_GPU>::value)
+            {
+                syncmem_h2d_op()(C, Cglobal_cpu, size_C_global);
+                delmem_dev_op()(Cglobal_cpu);
+            }
         }
         else
         {
-            Cglobal_cpu = C_global;
-        }
-        Parallel_Common::gatherv_data(C_tmp.data(),
-                                      send_counts,
-                                      Cglobal_cpu,
-                                      recv_counts.data(),
-                                      displs.data(),
-                                      col_world);
-        if (row_world != MPI_COMM_NULL)
-        {
-            Parallel_Common::reduce_data(Cglobal_cpu, size_C_global, row_world);
-        }
-        if (std::is_same<Device, base_device::DEVICE_GPU>::value)
-        {
-            syncmem_dev_op()(C_global, Cglobal_cpu, size_C_global);
-            delmem_dev_op()(Cglobal_cpu);
+            if (this->row_nproc > 1)
+            {
+                Parallel_Common::reduce_dev<T, Device>(C, size_C_local, row_world);
+            }
         }
     }
     else
     {
         T real_beta = row_rank == 0 ? beta : 0;
 #else
-    T real_beta = beta;
+        T real_beta = beta;
 #endif
-    ModuleBase::gemm_op<T, Device>()(ctx,
-                                     'C',
-                                     'N',
-                                     ncolA,
-                                     ncolB,
-                                     nrow,
-                                     &alpha,
-                                     A,
-                                     LDA,
-                                     B,
-                                     LDB,
-                                     &real_beta,
-                                     C_global,
-                                     LDC_global);
+        ModuleBase::gemm_op<T, Device>()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC);
 #ifdef __MPI
-    Parallel_Common::reduce_dev<T, Device>(C_global, size_C_global, row_world);
+        if (this->row_nproc > 1)
+        {
+            Parallel_Common::reduce_dev<T, Device>(C, size_C_local, row_world);
+        }
     }
 #endif
 }
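
For reference, the gather pattern behind the new gatherC path can be reproduced standalone. The sketch below is illustrative only: raw MPI_Allgather/MPI_Allgatherv stand in for the Parallel_Common wrappers, and the communicator and sizes are made up. It mirrors the count/displacement bookkeeping that set_dimension now performs only when gatherC is set, followed by the column-block gather done at the end of multiply:

```cpp
// Minimal sketch of the gatherC bookkeeping and gather, with raw MPI
// collectives standing in for Parallel_Common::gatherv_data. All sizes
// and data are hypothetical.
#include <mpi.h>
#include <cstdio>
#include <vector>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    MPI_Comm col_world = MPI_COMM_WORLD; // stand-in for the column communicator
    int col_rank = 0, col_nproc = 1;
    MPI_Comm_rank(col_world, &col_rank);
    MPI_Comm_size(col_world, &col_nproc);

    const int LDC = 4;              // leading dimension shared by C_local and C_global
    const int ncolB = col_rank + 1; // each rank owns a different number of B columns

    // Same bookkeeping as set_dimension: gather every rank's column count,
    // then build recv_counts and displs for the column-block gather.
    std::vector<int> colB_loc(col_nproc), recv_counts(col_nproc), displs(col_nproc);
    MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world);
    for (int ip = 0; ip < col_nproc; ip++)
    {
        recv_counts[ip] = LDC * colB_loc[ip];
    }
    displs[0] = 0;
    for (int ip = 1; ip < col_nproc; ip++)
    {
        displs[ip] = displs[ip - 1] + recv_counts[ip - 1];
    }
    const int size_C_local = ncolB * LDC;
    const int size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1];

    // Gather the local column block so every rank ends up with the same
    // C_global (column-major, contiguous column blocks per rank).
    std::vector<double> C_local(size_C_local, double(col_rank));
    std::vector<double> C_global(size_C_global);
    MPI_Allgatherv(C_local.data(), size_C_local, MPI_DOUBLE,
                   C_global.data(), recv_counts.data(), displs.data(), MPI_DOUBLE,
                   col_world);

    if (col_rank == 0)
    {
        std::printf("gathered %d elements\n", size_C_global);
    }
    MPI_Finalize();
    return 0;
}
```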

source/module_base/para_gemm.h

Lines changed: 27 additions & 23 deletions
@@ -12,9 +12,13 @@ namespace ModuleBase
 {
 /**
  * @brief this class is used to perform parallel matrix multiplication
- * C_global = alpha * A^+ * B + beta * C_global
- * Here, A and B are local matrices in each proc, and C_global is a global matrix gathered from all procs
- * All procs have their own C_global matrix with the same values.
+ * C = alpha * A^H * B + beta * C
+ * Here, A and B are local matrices in each proc,
+ * C can be C_local or C_global, depending on the value of gatherC
+ * C_local is a local matrix in each proc
+ * C_global is a global matrix gathered from all procs, and all procs have their own C_global matrix with the same
+ * values.
+ * C_global and C_local have the same LDC, but different column numbers
 */
 template <typename T, typename Device = base_device::DEVICE_CPU>
 class PGemmCN
@@ -24,14 +28,15 @@ class PGemmCN
     ~PGemmCN();

     /**
-     * @brief set the dimension of A, B, and C_global
+     * @brief set the dimension of A, B, and C
      *
      * @param ncolA number of columns of A, which is a local matrix in each proc
      * @param LDA leading dimension of A in each proc
      * @param ncolB number of columns of B, which is a local matrix in each proc
      * @param LDB leading dimension of B in each proc
      * @param nrow number of rows of A or B
-     * @param LDC_global leading dimension of C_global, which is the global C matrix gathered from all procs
+     * @param LDC leading dimension of C. C can be C_local or C_global
+     * @param gatherC whether to gather C_local to C_global
      */
     void set_dimension(
 #ifdef __MPI
@@ -43,47 +48,46 @@ class PGemmCN
         const int ncolB,
         const int LDB,
         const int nrow,
-        const int LDC_global);
+        const int LDC,
+        const bool gatherC = true);
+
     /**
-     * @brief calculate C_global = alpha * A^+ * B + beta * C_global
+     * @brief calculate C = alpha * A^H * B + beta * C
      *
-     * @param alpha
-     * @param A
-     * @param B
-     * @param beta
-     * @param C_global
      */
-    void multiply(const T alpha, const T* A, const T* B, const T beta, T* C_global);
+    void multiply(const T alpha, const T* A, const T* B, const T beta, T* C);
 #ifdef __MPI
     MPI_Comm col_world = MPI_COMM_NULL; ///< column communicator world
     MPI_Comm row_world = MPI_COMM_NULL; ///< row communicator world

     int col_rank = 0;  ///< rank in col_world
     int col_nproc = 1; ///< number of procs in col_world
     int row_rank = 0;  ///< rank in row_world
+    int row_nproc = 1; ///< number of procs in row_world

     std::vector<int> colA_loc; ///< [col_nproc] number of columns of A matrix in each proc
     int max_colA = 0;          ///< maximum number of columns of A matrix in all procs
     std::vector<int> colB_loc; ///<[col_nproc] number of columns of B matrix in each proc
-    std::vector<int> row_loc;  ///<[col_nproc] number of rows of C matrix in each proc

     std::vector<MPI_Request> requests; ///< MPI request
     std::vector<int> recv_counts;      ///< receive counts for gathering C_local to C_global
     std::vector<int> displs;           ///< displacements for gathering C_local to C_global
-    int send_counts = 0;   ///< send counts for gathering C_local to C_global
+    int size_C_local = 0;  ///< size of C_local, which is a local matrix in each proc
     int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs
+    bool gatherC = true;   ///< whether to gather C_local to C_global
 #endif
-    int ncolA = 0;      ///< number of columns of A, which is a local matrix in each proc
-    int ncolB = 0;      ///< number of columns of B, which is a local matrix in each proc
-    int nrow = 0;       ///< number of rows of A or B
-    int LDA = 0;        ///< leading dimension of A in each proc
-    int LDB = 0;        ///< leading dimension of B in each proc
-    int LDC_global = 0; ///< leading dimension of C_global, which is the global C matrix gathered from all procs
+    int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc
+    int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc
+    int nrow = 0;  ///< number of rows of A or B
+    int LDA = 0;   ///< leading dimension of A in each proc
+    int LDB = 0;   ///< leading dimension of B in each proc
+    int LDC = 0;   ///< leading dimension of C, which can be C_local or C_global
 private:
-    using resmem_dev_op  = base_device::memory::resize_memory_op<T, Device>;
+    using resmem_dev_op = base_device::memory::resize_memory_op<T, Device>;
     using delmem_dev_op = base_device::memory::delete_memory_op<T, Device>;
     using syncmem_dev_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
-
+    using syncmem_d2h_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, Device>;
+    using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
 };
 } // namespace ModuleBase
 #endif
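
For completeness, a hypothetical usage sketch of the updated interface follows. The leading arguments of set_dimension (the two communicators, ncolA, LDA) fall outside the hunks shown above, so their order is assumed from the .cpp body; the include path and all sizes are likewise illustrative, not taken from the commit.

```cpp
// Hypothetical usage of PGemmCN with the new gatherC flag. The argument order
// of set_dimension before ncolB (comm_col, comm_row, ncolA, LDA) is assumed,
// since those parameter lines are outside the hunks shown above.
#include "module_base/para_gemm.h"
#include <mpi.h>
#include <vector>

void example(MPI_Comm comm_col, MPI_Comm comm_row)
{
    using T = double;
    const int nrow = 100, ncolA = 8, ncolB = 6;

    // C = alpha * A^H * B + beta * C is (sum of ncolA over column ranks) x ncolB
    // locally, so LDC must cover the total number of A columns in comm_col.
    int LDC = 0;
    MPI_Allreduce(&ncolA, &LDC, 1, MPI_INT, MPI_SUM, comm_col);

    ModuleBase::PGemmCN<T> pgemm; // Device defaults to base_device::DEVICE_CPU
    // gatherC = false: C keeps only the local ncolB columns and is reduced
    // over comm_row; no gather over comm_col takes place.
    pgemm.set_dimension(comm_col, comm_row, ncolA, nrow, ncolB, nrow, nrow, LDC, false);

    std::vector<T> A(nrow * ncolA, T(1)), B(nrow * ncolB, T(1)), C(LDC * ncolB, T(0));
    pgemm.multiply(T(1), A.data(), B.data(), T(0), C.data());
    // With gatherC = true (the default), C must instead hold LDC * (total ncolB)
    // elements, and every rank receives the same gathered C_global.
}
```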
