Skip to content

Commit 23fc25e

Browse files
committed
add para_linear_transform_op
1 parent fcef6cd commit 23fc25e

File tree

6 files changed

+426
-3
lines changed

6 files changed

+426
-3
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ OBJS_HSOLVER=diago_cg.o\
333333
diago_david.o\
334334
diago_dav_subspace.o\
335335
diago_bpcg.o\
336+
para_linear_transform.o\
336337
hsolver.o\
337338
hsolver_pw.o\
338339
hsolver_lcaopw.o\

source/module_hsolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ list(APPEND objects
44
diago_david.cpp
55
diago_dav_subspace.cpp
66
diago_bpcg.cpp
7+
para_linear_transform.cpp
78
hsolver_pw.cpp
89
hsolver_lcaopw.cpp
910
hsolver_pw_sdft.cpp
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include "para_linear_transform.h"
2+
#include <vector>
3+
#include <algorithm>
4+
namespace hsolver
5+
{
6+
template <typename T, typename Device>
7+
void para_linear_transform_op<T, Device>::operator()(T* A,
8+
const T alpha,
9+
const T beta,
10+
const T* U_global,
11+
const int& nrow,
12+
const int& LDA,
13+
const int& ncol_loc,
14+
const int& ncol_glo,
15+
#ifdef __MPI
16+
MPI_Comm col_world,
17+
#endif
18+
const int rank_col,
19+
const int nproc_col
20+
21+
)
22+
{
23+
const Device* ctx = {};
24+
#ifdef __MPI
25+
if (nproc_col > 1)
26+
{
27+
std::vector<int> colA_loc(nproc_col);
28+
MPI_Allgather(&ncol_loc, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world);
29+
std::vector<int> start_col(nproc_col);
30+
start_col[0] = 0;
31+
for (int ip = 1; ip < nproc_col; ++ip)
32+
{
33+
start_col[ip] = start_col[ip - 1] + colA_loc[ip - 1];
34+
}
35+
int max_col = *std::max_element(colA_loc.begin(), colA_loc.end());
36+
std::vector<MPI_Request> requests(nproc_col);
37+
38+
std::vector<T> A_tmp(max_col * LDA);
39+
T* A_tmp_device = A_tmp.data();
40+
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
41+
{
42+
A_tmp_device = nullptr;
43+
resmem_dev_op()(A_tmp_device, max_col * LDA);
44+
}
45+
T* A_tmp2 = nullptr;
46+
resmem_dev_op()(A_tmp2, ncol_loc * LDA);
47+
syncmem_dev_op()(A_tmp2, A, ncol_loc * LDA);
48+
T* A_sum = nullptr;
49+
resmem_dev_op()(A_sum, ncol_loc * LDA);
50+
setmem_dev_op()(A_sum, 0.0, ncol_loc * LDA);
51+
52+
// Send
53+
for (int ip = 0; ip < nproc_col; ++ip)
54+
{
55+
if (rank_col != ip)
56+
{
57+
int size = LDA * ncol_loc;
58+
Parallel_Common::isend_dev<T, Device>(A, size, ip, 0, col_world, &requests[ip], A_tmp.data());
59+
}
60+
}
61+
62+
// Receive
63+
T* U_local = nullptr;
64+
resmem_dev_op()(U_local, max_col * ncol_loc);
65+
const int start = start_col[rank_col];
66+
for (int ip = 0; ip < nproc_col; ++ip)
67+
{
68+
T real_beta = ip == 0 ? beta : 0;
69+
const int start_row = start_col[ip];
70+
const int ncol_ip = colA_loc[ip];
71+
// get U_local
72+
for (int i = 0; i < ncol_loc; ++i)
73+
{
74+
const T* U_glo_tmp = U_global + start_row + (i + start) * ncol_glo;
75+
syncmem_dev_op()(U_local + i * ncol_ip, U_glo_tmp, ncol_ip);
76+
}
77+
78+
if (ip == rank_col)
79+
{
80+
ModuleBase::gemm_op<T, Device>()(ctx,
81+
'N',
82+
'N',
83+
nrow,
84+
ncol_loc,
85+
ncol_ip,
86+
&alpha,
87+
A,
88+
LDA,
89+
U_local,
90+
ncol_ip,
91+
&real_beta,
92+
A_tmp2,
93+
LDA);
94+
}
95+
else
96+
{
97+
int size = LDA * ncol_ip;
98+
MPI_Status status;
99+
Parallel_Common::recv_dev<T, Device>(A_tmp_device, size, ip, 0, col_world, &status, A_tmp.data());
100+
MPI_Wait(&requests[ip], &status);
101+
ModuleBase::gemm_op<T, Device>()(ctx,
102+
'N',
103+
'N',
104+
nrow,
105+
ncol_loc,
106+
ncol_ip,
107+
&alpha,
108+
A_tmp_device,
109+
LDA,
110+
U_local,
111+
ncol_ip,
112+
&real_beta,
113+
A_tmp2,
114+
LDA);
115+
}
116+
// sum all the results
117+
T one = 1.0;
118+
ModuleBase::axpy_op<T, Device>()(ctx, ncol_loc * LDA, &one, A_tmp2, 1, A_sum, 1);
119+
}
120+
syncmem_dev_op()(A, A_sum, ncol_loc * LDA);
121+
delmem_dev_op()(U_local);
122+
delmem_dev_op()(A_tmp2);
123+
delmem_dev_op()(A_sum);
124+
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
125+
{
126+
delmem_dev_op()(A_tmp_device);
127+
}
128+
}
129+
else
130+
#endif
131+
{
132+
T* A_tmp = nullptr;
133+
resmem_dev_op()(A_tmp, LDA * ncol_glo);
134+
syncmem_dev_op()(A_tmp, A, LDA * ncol_loc);
135+
ModuleBase::gemm_op<T, Device>()(ctx,
136+
'N',
137+
'N',
138+
nrow,
139+
ncol_glo,
140+
ncol_glo,
141+
&alpha,
142+
A_tmp,
143+
LDA,
144+
U_global,
145+
ncol_glo,
146+
&beta,
147+
A,
148+
LDA);
149+
delmem_dev_op()(A_tmp);
150+
}
151+
};
152+
153+
template struct para_linear_transform_op<double, base_device::DEVICE_CPU>;
154+
template struct para_linear_transform_op<std::complex<double>, base_device::DEVICE_CPU>;
155+
template struct para_linear_transform_op<std::complex<float>, base_device::DEVICE_CPU>;
156+
#if ((defined __CUDA) || (defined __ROCM))
157+
template struct para_linear_transform_op<double, base_device::DEVICE_GPU>;
158+
template struct para_linear_transform_op<std::complex<double>, base_device::DEVICE_GPU>;
159+
template struct para_linear_transform_op<std::complex<float>, base_device::DEVICE_GPU>;
160+
#endif
161+
} // namespace hsolver
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef __PARA_LINEAR_TRANSFORM_H__
2+
#define __PARA_LINEAR_TRANSFORM_H__
3+
#include "module_base/kernels/math_kernel_op.h"
4+
#include "module_base/module_device/device.h"
5+
#include "module_base/module_device/memory_op.h"
6+
#include "module_base/parallel_device.h"
7+
#ifdef __MPI
8+
#include "mpi.h"
9+
#endif
10+
namespace hsolver
11+
{
12+
13+
template <typename T, typename Device>
14+
struct para_linear_transform_op
15+
{
16+
using syncmem_dev_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
17+
using resmem_dev_op = base_device::memory::resize_memory_op<T, Device>;
18+
using setmem_dev_op = base_device::memory::set_memory_op<T, Device>;
19+
using delmem_dev_op = base_device::memory::delete_memory_op<T, Device>;
20+
/**
21+
* @brief A_global = alpha * A_global * U_global + beta * A_global
22+
* A is a local matrix with nrow rows and ncol_loc columns
23+
* U_global is a matrix with ncol_glo rows and ncol_glo columns
24+
* @example rotate wave functions: A = A * U
25+
* orthogonalize wave functions: A = A - A * U
26+
*
27+
* @param A : input/output matrix
28+
* @param alpha : alpha
29+
* @param beta : beta
30+
* @param U_global : input matrix
31+
* @param nrow : number of rows of A
32+
* @param LDA : leading dimension of A
33+
* @param ncol_loc : number of columns of A
34+
* @param ncol_glo : number of columns and rows of U_global
35+
* @param col_world : column communicator world
36+
* @param rank_col : rank of col_world
37+
* @param nproc_col : number of processes in col_world
38+
*
39+
*/
40+
void operator()(T* A,
41+
const T alpha,
42+
const T beta,
43+
const T* U_global,
44+
const int& nrow,
45+
const int& LDA,
46+
const int& ncol_loc,
47+
const int& ncol_glo,
48+
#ifdef __MPI
49+
MPI_Comm col_world,
50+
#endif
51+
const int rank_col,
52+
const int nproc_col);
53+
};
54+
} // namespace hsolver
55+
#endif

source/module_hsolver/test/CMakeLists.txt

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ if (ENABLE_MPI)
1212
AddTest(
1313
TARGET HSolver_bpcg
1414
LIBS parameter ${math_libs} base psi device container
15-
SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../diago_iter_assist.cpp
15+
SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp
1616
../../module_basis/module_pw/test/test_tool.cpp
1717
../../module_hamilt_general/operator.cpp
1818
../../module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp
@@ -77,13 +77,13 @@ if (ENABLE_MPI)
7777
AddTest(
7878
TARGET HSolver_pw
7979
LIBS parameter ${math_libs} psi device base container
80-
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
80+
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp
8181
)
8282

8383
AddTest(
8484
TARGET HSolver_sdft
8585
LIBS parameter ${math_libs} psi device base container
86-
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
86+
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp
8787
)
8888

8989
if(ENABLE_LCAO)
@@ -159,6 +159,17 @@ AddTest(
159159
SOURCES test_diago_hs_para.cpp ../diag_hs_para.cpp ../diago_pxxxgvx.cpp ../diago_elpa.cpp ../diago_scalapack.cpp
160160
)
161161

162+
AddTest(
163+
TARGET hsolver_linear_trans
164+
LIBS parameter ${math_libs} base device MPI::MPI_CXX
165+
SOURCES test_para_linear_trans.cpp ../para_linear_transform.cpp
166+
)
167+
168+
add_test(NAME hsolver_para_linear_trans
169+
COMMAND mpirun -np 4 ./hsolver_linear_trans
170+
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
171+
)
172+
162173
find_program(BASH bash)
163174
if (ENABLE_MPI)
164175
add_test(NAME HSolver_cg_parallel

0 commit comments

Comments
 (0)