
Commit e87414c

Feature: Add planewave parallelization support for BPCG method (#5849)
* Substitute gemm for einsum in rotate_wf
* Add plane-wave parallel support for inner-product-like gemm_op in bpcg
* Add reduce for dot ops used in bpcg
* Add reduce for manual inner-product (for-loop) ops used in bpcg
* Update docs now that BPCG supports plane-wave parallelization
* Update Autotest.sh to run the BPCG test with MPI np=4
* Remove unused code and redundancies
* Update result.ref for the BPCG multicore test
1 parent aced178 commit e87414c
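Every change in this commit follows one rule: under plane-wave parallelization, each MPI rank in a pool holds only a slice of the `n_basis` plane-wave coefficients, so any inner-product-like quantity (a dot product, or a subspace matrix such as `psi^H * hpsi`) computed from local data is a partial sum that must be summed across the pool. Operations that contract only over the band index, such as the wavefunction rotation, are already complete on every rank. A minimal standalone sketch of the reduce pattern, using a raw `MPI_Allreduce` where ABACUS wraps the call in `Parallel_Reduce::reduce_pool` (all names and sizes below are illustrative, not from the commit):

```cpp
// Sketch: a band inner product under plane-wave parallelization.
#include <mpi.h>
#include <complex>
#include <cstdio>
#include <vector>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Each rank owns a contiguous slice of the plane-wave coefficients.
    const int n_basis_local = 1000;
    std::vector<std::complex<double>> psi(n_basis_local, {0.1, 0.0});
    std::vector<std::complex<double>> hpsi(n_basis_local, {0.2, 0.0});

    // Local partial sum of <psi|H|psi> over this rank's slice only.
    std::complex<double> dot{0.0, 0.0};
    for (int i = 0; i < n_basis_local; ++i) {
        dot += std::conj(psi[i]) * hpsi[i];
    }

    // Inner-product-like result: sum the partial values across ranks,
    // treating the complex value as two doubles. ABACUS restricts this
    // to the pool communicator instead of MPI_COMM_WORLD.
    MPI_Allreduce(MPI_IN_PLACE, &dot, 2, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);

    if (rank == 0) { std::printf("<psi|H|psi> = %g\n", dot.real()); }
    MPI_Finalize();
    return 0;
}
```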

File tree: 6 files changed (+40 −22 lines)

docs/advanced/scf/hsolver.md
Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@
 
 The method of explicitly solving the KS equation can be chosen by the variable "ks_solver" in the INPUT file.
 
-When "basis_type = pw", `ks_solver` can be `cg`, `bpcg` or `dav`. The `bpcg` method only supports k-point parallelism currently. The default setting `cg` is recommended, which is the band-by-band conjugate gradient diagonalization method. The setting `dav`, the block Davidson diagonalization method, is worth trying to improve performance.
+When "basis_type = pw", `ks_solver` can be `cg`, `bpcg` or `dav`. The default setting `cg` is recommended, which is the band-by-band conjugate gradient diagonalization method. The setting `dav`, the block Davidson diagonalization method, is worth trying to improve performance.
 
 When "basis_type = lcao", `ks_solver` can be `genelpa` or `scalapack_gvx`. The default setting `genelpa` is recommended; it is based on ELPA (Eigenvalue Solvers for Petaflop Applications, https://elpa.mpcdf.mpg.de/), the kernel is chosen automatically by GenELPA (https://github.com/pplab/GenELPA), and it is usually faster than "scalapack_gvx", which is based on ScaLAPACK (Scalable Linear Algebra PACKage).
 
```
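For context, selecting the solver described above takes a single keyword in the ABACUS INPUT file. A minimal fragment (illustrative; all other required parameters omitted):

```
INPUT_PARAMETERS
basis_type  pw
ks_solver   bpcg
```

With this commit, such a `bpcg` run is no longer limited to k-point parallelism and can use multiple MPI ranks per k-point.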

source/module_hsolver/diago_bpcg.cpp
Lines changed: 25 additions & 18 deletions

```diff
@@ -115,6 +115,8 @@ void DiagoBPCG<T, Device>::orth_cholesky(
               hsub_out.data<T>(),
               this->n_band); //ldc
 
+    Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
+
     // set hsub matrix to lower format;
     ct::kernels::set_matrix<T, ct_Device>()(
         'L', hsub_out.data<T>(), this->n_band);
@@ -167,7 +169,6 @@ void DiagoBPCG<T, Device>::orth_projection(
         /*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_in);
     // hsub_in = ct::op::einsum("ij,kj->ik", grad_out, psi_in, option);
 
-    // this->orth_projection(this->psi, this->hsub, this->grad);
     // gemm: hsub_in(n_band x n_band) = psi_in^T(n_band x n_basis) * grad_out(n_basis x n_band)
     gemm_op()(this->ctx,
               'C',
@@ -184,6 +185,8 @@ void DiagoBPCG<T, Device>::orth_projection(
               hsub_in.data<T>(),
               this->n_band); //ldc
 
+    Parallel_Reduce::reduce_pool(hsub_in.data<T>(), this->n_band * this->n_band);
+
     // set_matrix_op()('L', hsub_in->data<T>(), this->n_band);
     option = ct::EinsumOption(
         /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/-1.0, /*beta=*/1.0, /*Tensor out=*/&grad_out);
@@ -205,6 +208,8 @@ void DiagoBPCG<T, Device>::orth_projection(
               grad_out.data<T>(),
               this->n_basis); //ldc
 
+    // * This type of non inner product like operation does not need reduce!
+
     return;
 }
 
@@ -216,25 +221,25 @@ void DiagoBPCG<T, Device>::rotate_wf(
 {
     ct::EinsumOption option(
         /*conj_x=*/false, /*conj_y=*/false, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&workspace_in);
-    workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
+    // workspace_in = ct::op::einsum("ij,jk->ik", hsub_in, psi_out, option);
 
-    // this->rotate_wf(hsub_out, psi_out, workspace_in);
-    // this->orth_cholesky(this->work, this->psi, this->hpsi, this->hsub);
     // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
-    // gemm_op()(this->ctx,
-    //           'N',
-    //           'N',
-    //           this->n_basis, //m
-    //           this->n_band, //n
-    //           this->n_band, //k
-    //           this->one, //1.0
-    //           psi_out.data<T>(),
-    //           this->n_basis, //lda
-    //           hsub_in.data<T>(),
-    //           this->n_band, //ldb
-    //           this->zero, //0.0
-    //           workspace_in.data<T>(),
-    //           this->n_basis); //ldc
+    gemm_op()(this->ctx,
+              'N',
+              'N',
+              this->n_basis, //m
+              this->n_band, //n
+              this->n_band, //k
+              this->one, //1.0
+              psi_out.data<T>(),
+              this->n_basis, //lda
+              hsub_in.data<T>(),
+              this->n_band, //ldb
+              this->zero, //0.0
+              workspace_in.data<T>(),
+              this->n_basis); //ldc
+
+    // * This type of non inner product like operation does not need reduce!
 
     syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);
 
@@ -281,6 +286,8 @@ void DiagoBPCG<T, Device>::diag_hsub(
               hsub_out.data<T>(),
               this->n_band); //ldc
 
+    Parallel_Reduce::reduce_pool(hsub_out.data<T>(), this->n_band * this->n_band);
+
     ct::kernels::lapack_dnevd<T, ct_Device>()('V', 'U', hsub_out.data<T>(), this->n_band, eigenvalue_out.data<Real>());
 
     return;
```
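The dividing line the added `reduce_pool` calls and "does not need reduce" comments draw is visible in the gemm shapes: `hsub = psi^H * hpsi` contracts over the distributed `n_basis` index, so each rank's gemm yields only a partial sum, while `workspace = psi * hsub` contracts over the replicated `n_band` index and is complete locally. A small serial sketch of the two contraction patterns (hypothetical column-slice layout, plain loops standing in for `gemm_op`):

```cpp
#include <complex>
#include <vector>

using T = std::complex<double>;

// hsub(i,j) = sum_g conj(psi(g,i)) * hpsi(g,j). The contraction index g is
// the plane-wave index, split across ranks: the result is a PARTIAL sum and
// a pool reduction (Parallel_Reduce::reduce_pool) must follow.
void subspace_matrix(const std::vector<T>& psi, const std::vector<T>& hpsi,
                     std::vector<T>& hsub, int n_basis_local, int n_band)
{
    for (int i = 0; i < n_band; ++i) {
        for (int j = 0; j < n_band; ++j) {
            T s{0.0, 0.0};
            for (int g = 0; g < n_basis_local; ++g) {
                s += std::conj(psi[i * n_basis_local + g]) * hpsi[j * n_basis_local + g];
            }
            hsub[i * n_band + j] = s; // partial until reduced over the pool
        }
    }
}

// work(g,i) = sum_j psi(g,j) * hsub(j,i). The contraction index j is the
// band index, which every rank holds in full: the result is already complete
// for the local g slice, so no reduction is needed.
void rotate(const std::vector<T>& psi, const std::vector<T>& hsub,
            std::vector<T>& work, int n_basis_local, int n_band)
{
    for (int i = 0; i < n_band; ++i) {
        for (int g = 0; g < n_basis_local; ++g) {
            T s{0.0, 0.0};
            for (int j = 0; j < n_band; ++j) {
                s += psi[j * n_basis_local + g] * hsub[j * n_band + i];
            }
            work[i * n_basis_local + g] = s;
        }
    }
}
```

This is why `orth_cholesky`, `orth_projection`, and `diag_hsub` gain a `reduce_pool` after their subspace gemm, while `rotate_wf` and the in-place gradient update do not.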

source/module_hsolver/diago_bpcg.h
Lines changed: 3 additions & 0 deletions

```diff
@@ -340,6 +340,9 @@ class DiagoBPCG
     using resmem_complex_op = ct::kernels::resize_memory<T, ct_Device>;
     using syncmem_complex_op = ct::kernels::synchronize_memory<T, ct_Device, ct_Device>;
 
+    // note: these operators use template parameter base_device::Device_*
+    // defined in module_base/module_device/types.h
+    // different from ct_Device!
     using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
     using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
     using gemm_op = hsolver::gemm_op<T, Device>;
```
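The added comment warns about a real foot-gun: the class mixes two independent device-tag families, `ct_Device` for the tensor kernels and `base_device::DEVICE_*` for the hsolver kernels, and tags with identical-looking names are still distinct types. A minimal sketch of the situation, with simplified stand-in tags (the real definitions live in the headers the comment cites):

```cpp
#include <type_traits>

// Stand-ins for the two independent device-tag families (shapes assumed).
namespace ct          { struct DEVICE_CPU {}; }
namespace base_device { struct DEVICE_CPU {}; }

// A kernel specialized for one family does not match tags from the other,
// even though the unqualified names look identical.
template <typename Device> struct gemm_op_sketch;
template <> struct gemm_op_sketch<base_device::DEVICE_CPU> { /* CPU path */ };

static_assert(!std::is_same<ct::DEVICE_CPU, base_device::DEVICE_CPU>::value,
              "same-looking names, distinct types");

int main() { return 0; }
```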

source/module_hsolver/kernels/math_kernel_op.cpp
Lines changed: 8 additions & 0 deletions

```diff
@@ -24,6 +24,7 @@ struct line_minimize_with_block_op<T, base_device::DEVICE_CPU>
         Real theta = 0.0, cos_theta = 0.0, sin_theta = 0.0;
         auto A = reinterpret_cast<const Real*>(grad_out + band_idx * n_basis_max);
         Real norm = BlasConnector::dot(2 * n_basis, A, 1, A, 1);
+        Parallel_Reduce::reduce_pool(norm);
         norm = 1.0 / sqrt(norm);
         for (int basis_idx = 0; basis_idx < n_basis; basis_idx++)
         {
@@ -34,6 +35,9 @@ struct line_minimize_with_block_op<T, base_device::DEVICE_CPU>
             epsilo_1 += std::real(grad_out[item] * std::conj(hpsi_out[item]));
             epsilo_2 += std::real(grad_out[item] * std::conj(hgrad_out[item]));
         }
+        Parallel_Reduce::reduce_pool(epsilo_0);
+        Parallel_Reduce::reduce_pool(epsilo_1);
+        Parallel_Reduce::reduce_pool(epsilo_2);
         theta = 0.5 * std::abs(std::atan(2 * epsilo_1 / (epsilo_0 - epsilo_2)));
         cos_theta = std::cos(theta);
         sin_theta = std::sin(theta);
@@ -71,6 +75,7 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_CPU>
         T grad_1 = {0.0, 0.0};
         auto A = reinterpret_cast<const Real*>(psi_out + band_idx * n_basis_max);
         Real norm = BlasConnector::dot(2 * n_basis, A, 1, A, 1);
+        Parallel_Reduce::reduce_pool(norm);
         norm = 1.0 / sqrt(norm);
         for (int basis_idx = 0; basis_idx < n_basis; basis_idx++)
         {
@@ -79,6 +84,7 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_CPU>
             hpsi_out[item] *= norm;
             epsilo += std::real(hpsi_out[item] * std::conj(psi_out[item]));
         }
+        Parallel_Reduce::reduce_pool(epsilo);
         for (int basis_idx = 0; basis_idx < n_basis; basis_idx++)
         {
             auto item = band_idx * n_basis_max + basis_idx;
@@ -87,6 +93,8 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_CPU>
             err += grad_2;
             beta += grad_2 / prec_in[basis_idx]; /// Mark here as we should div the prec?
         }
+        Parallel_Reduce::reduce_pool(err);
+        Parallel_Reduce::reduce_pool(beta);
         for (int basis_idx = 0; basis_idx < n_basis; basis_idx++)
         {
             auto item = band_idx * n_basis_max + basis_idx;
```
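The scalar accumulators in these kernels (`norm`, `epsilo_0/1/2`, `epsilo`, `err`, `beta`) all loop over the local `n_basis` slice only, so each is a partial sum until the added `reduce_pool` calls run. A plausible shape for such an in-place scalar reduction, sketched with raw MPI (the actual `Parallel_Reduce::reduce_pool` signature and pool communicator in ABACUS may differ; this is an assumption for illustration):

```cpp
#include <mpi.h>

// Stand-in for the plane-wave pool communicator; in ABACUS the reduction
// runs over the pool, not over all ranks.
static MPI_Comm pool_comm_sketch = MPI_COMM_WORLD;

// In-place sum of one scalar across the pool: afterwards every rank holds
// the full value, so the subsequent theta/line-minimization arithmetic is
// consistent across ranks.
inline void reduce_pool_sketch(double& value)
{
    MPI_Allreduce(MPI_IN_PLACE, &value, 1, MPI_DOUBLE, MPI_SUM, pool_comm_sketch);
}
```

Without these reductions, each rank would normalize and line-minimize against its own partial norms and energies, and the ranks of a pool would silently diverge.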

tests/integrate/102_PW_BPCG/result.ref
Lines changed: 2 additions & 2 deletions

```diff
@@ -1,7 +1,7 @@
 etotref -4869.74705201
 etotperatomref -2434.87352600
-totalforceref 5.19483000
-totalstressref 37241.45334600
+totalforceref 5.19522000
+totalstressref 37241.49490600
 pointgroupref C_1
 spacegroupref C_1
 nksibzref 8
```

tests/integrate/Autotest.sh
Lines changed: 1 addition & 1 deletion

```diff
@@ -250,7 +250,7 @@ for dir in $testdir; do
         TIMEFORMAT='[----------] Time elapsed: %R seconds'
         #parallel test
         time {
-            if [ "$case" = "282_NO_RPA" -o "$dir" = "102_PW_BPCG" ]; then
+            if [ "$case" = "282_NO_RPA" ]; then
                 mpirun -np 1 $abacus > log.txt
             else
                 mpirun -np $np $abacus > log.txt
```
