Skip to content

Commit 7ddd28f

Browse files
committed
fix BPCG
1 parent c15c823 commit 7ddd28f

File tree

18 files changed

+244
-43
lines changed

18 files changed

+244
-43
lines changed

source/module_cell/cal_atoms_info.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,16 @@ class CalAtomsInfo
7373
if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel
7474
{
7575
para.sys.nbands_l = para.inp.nbands / para.inp.bndpar;
76-
if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar)
76+
if (GlobalV::MY_BNDGROUP < para.inp.nbands % para.inp.bndpar)
7777
{
7878
para.sys.nbands_l++;
7979
}
8080
}
81+
// temporary code
82+
if (GlobalV::MY_BNDGROUP == 0 || para.inp.ks_solver == "bpcg")
83+
{
84+
para.sys.ks_run = true;
85+
}
8186
return;
8287
}
8388
};

source/module_elecstate/elecstate_print.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter)
247247
{
248248
// check the band energy.
249249
bool wrong = false;
250-
for (int ib = 0; ib < PARAM.inp.nbands; ++ib)
250+
for (int ib = 0; ib < PARAM.globalv.nbands_l; ++ib)
251251
{
252252
if (std::abs(this->ekb(ik, ib)) > 1.0e10)
253253
{
@@ -269,7 +269,7 @@ void ElecState::print_band(const int& ik, const int& printe, const int& iter)
269269
GlobalV::ofs_running << " Energy (eV) & Occupations for spin=" << this->klist->isk[ik] + 1
270270
<< " K-point=" << ik + 1 << std::endl;
271271
GlobalV::ofs_running << std::setiosflags(std::ios::showpoint);
272-
for (int ib = 0; ib < PARAM.inp.nbands; ib++)
272+
for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++)
273273
{
274274
GlobalV::ofs_running << " " << std::setw(6) << ib + 1 << std::setw(15)
275275
<< this->ekb(ik, ib) * ModuleBase::Ry_to_eV;

source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void Stochastic_Iter<T, Device>::orthog(const int& ik, psi::Psi<T, Device>& psi,
5858
{
5959
ModuleBase::TITLE("Stochastic_Iter", "orthog");
6060
ModuleBase::timer::tick("Stochastic_Iter", "orthog");
61-
const int nbands_l = psi.get_nbands();
61+
int nbands_l = psi.get_nbands();
6262
const int nbands = PARAM.inp.nbands;
6363
// orthogonal part
6464
if (nbands > 0)
@@ -74,24 +74,63 @@ void Stochastic_Iter<T, Device>::orthog(const int& ik, psi::Psi<T, Device>& psi,
7474
// orthogonal part
7575
T* sum = nullptr;
7676
resmem_complex_op()(sum, nbands * nchipk);
77-
// sum(b<NBANDS, a<nchi) = < psi_b | chi_a >
78-
ModuleBase::PGemmCN<T, Device> pmmcn;
77+
78+
if(PARAM.globalv.all_ks_run)
79+
{
80+
// sum(b<NBANDS, a<nchi) = < psi_b | chi_a >
81+
ModuleBase::PGemmCN<T, Device> pmmcn;
7982
#ifdef __MPI
80-
pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2);
83+
pmmcn.set_dimension(BP_WORLD, POOL_WORLD, nbands_l, npwx, nchipk, npwx, npw, nbands, 2);
8184
#else
82-
pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2);
85+
pmmcn.set_dimension(nbands_l, npwx, nchipk, npwx, npw, nbands, 2);
8386
#endif
84-
pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum);
85-
86-
// psi -= psi * sum
87-
hsolver::PLinearTransform<T, Device> pltrans;
87+
pmmcn.multiply(1.0, &psi(ik, 0, 0), wfgout, 0.0, sum);
88+
89+
// psi -= psi * sum
90+
hsolver::PLinearTransform<T, Device> pltrans;
8891
#ifdef __MPI
89-
pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true);
92+
pltrans.set_dimension(npw, nbands_l, nchipk, npwx, BP_WORLD, true);
9093
#else
91-
pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true);
94+
pltrans.set_dimension(npw, nbands_l, nchipk, npwx, true);
9295
#endif
93-
pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout);
94-
96+
pltrans.act(-1.0, &psi(ik, 0, 0), sum, 1.0, wfgout);
97+
}
98+
else
99+
{
100+
// sum(b<NBANDS, a<nchi) = < psi_b | chi_a >
101+
ModuleBase::gemm_op<T, Device>()(ctx,
102+
'C',
103+
'N',
104+
nbands,
105+
nchipk,
106+
npw,
107+
&ModuleBase::ONE,
108+
&psi(ik, 0, 0),
109+
npwx,
110+
wfgout,
111+
npwx,
112+
&ModuleBase::ZERO,
113+
sum,
114+
nbands);
115+
Parallel_Reduce::reduce_pool(sum, nbands * nchipk);
116+
117+
// psi -= psi * sum
118+
ModuleBase::gemm_op<T, Device>()(ctx,
119+
'N',
120+
'N',
121+
npw,
122+
nchipk,
123+
nbands,
124+
&ModuleBase::NEG_ONE,
125+
&psi(ik, 0, 0),
126+
npwx,
127+
sum,
128+
nbands,
129+
&ModuleBase::ONE,
130+
wfgout,
131+
npwx);
132+
}
133+
95134
delmem_complex_op()(sum);
96135
}
97136
ModuleBase::timer::tick("Stochastic_Iter", "orthog");

source/module_hsolver/diago_bpcg.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "module_base/global_function.h"
1313
#include "module_base/kernels/math_kernel_op.h"
1414
#include "para_linear_transform.h"
15+
#include "module_parameter/parameter.h"
1516

1617
namespace hsolver {
1718

@@ -44,9 +45,9 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nband_l, const i
4445

4546
// All column major tensors
4647

47-
this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band}));
48+
this->beta = std::move(ct::Tensor(r_type, device_type, {this->n_band_l}));
4849
this->eigen = std::move(ct::Tensor(r_type, device_type, {this->n_band}));
49-
this->err_st = std::move(ct::Tensor(r_type, device_type, {this->n_band}));
50+
this->err_st = std::move(ct::Tensor(r_type, device_type, {this->n_band_l}));
5051

5152
this->hsub = std::move(ct::Tensor(t_type, device_type, {this->n_band, this->n_band}));
5253

@@ -175,7 +176,7 @@ void DiagoBPCG<T, Device>::rotate_wf(
175176
{
176177
// gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
177178
this->plintrans.act(1.0, psi_out.data<T>(), hsub_in.data<T>(), 0.0, workspace_in.data<T>());
178-
syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band * this->n_basis);
179+
syncmem_complex_op()(psi_out.template data<T>(), workspace_in.template data<T>(), this->n_band_l * this->n_basis);
179180

180181
return;
181182
}
@@ -187,7 +188,7 @@ void DiagoBPCG<T, Device>::calc_hpsi_with_block(
187188
ct::Tensor& hpsi_out)
188189
{
189190
// calculate hpsi for the local bands only (n_band_l columns)
190-
hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band);
191+
hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band_l);
191192
}
192193

193194
template<typename T, typename Device>
@@ -256,17 +257,17 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
256257
{
257258
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
258259
// Get the pointer of the input psi
259-
this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band, this->n_basis}));
260+
this->psi = std::move(ct::TensorMap(psi_in /*psi_in.get_pointer()*/, t_type, device_type, {this->n_band_l, this->n_basis}));
260261

261262
// Update the precondition array
262263
this->calc_prec();
263264

264265
// Improving the initial guess of the wave function psi through a subspace diagonalization.
265266
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
266267

267-
setmem_complex_op()(this->grad_old.template data<T>(), 0, this->n_basis * this->n_band);
268+
setmem_complex_op()(this->grad_old.template data<T>(), 0, this->n_basis * this->n_band_l);
268269

269-
setmem_var_op()(this->beta.template data<Real>(), std::numeric_limits<Real>::infinity(), this->n_band);
270+
setmem_var_op()(this->beta.template data<Real>(), std::numeric_limits<Real>::infinity(), this->n_band_l);
270271

271272
int ntry = 0;
272273
int max_iter = current_scf_iter > 1 ?
@@ -290,7 +291,7 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
290291
this->orth_projection(this->psi, this->hsub, this->grad);
291292

292293
// this->grad_old = this->grad;
293-
syncmem_complex_op()(this->grad_old.template data<T>(), this->grad.template data<T>(), n_basis * n_band);
294+
syncmem_complex_op()(this->grad_old.template data<T>(), this->grad.template data<T>(), n_basis * n_band_l);
294295

295296
// Calculate H|grad> matrix
296297
this->calc_hpsi_with_block(hpsi_func, this->grad.template data<T>(), /*this->grad_wrapper[0],*/ this->hgrad);
@@ -311,7 +312,14 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
311312

312313
this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen);
313314

314-
syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data<Real>(), this->n_band);
315+
int start_nband = 0;
316+
#ifdef __MPI
317+
if (PARAM.inp.bndpar > 1)
318+
{
319+
start_nband = this->plintrans.start_colB[GlobalV::MY_BNDGROUP];
320+
}
321+
#endif
322+
syncmem_var_d2h_op()(eigenvalue_in, this->eigen.template data<Real>() + start_nband, this->n_band_l);
315323

316324
return;
317325
}

source/module_hsolver/para_linear_transform.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ namespace hsolver
66
{
77
template <typename T, typename Device>
88
void PLinearTransform<T, Device>::set_dimension(const int nrowA,
9-
const int ncolA,
10-
const int ncolB,
11-
const int LDA,
9+
const int ncolA,
10+
const int ncolB,
11+
const int LDA,
1212
#ifdef __MPI
13-
MPI_Comm col_world,
13+
MPI_Comm col_world,
1414
#endif
15-
const bool localU)
15+
const bool localU)
1616
{
1717
this->nrowA = nrowA;
1818
this->ncolA = ncolA;
@@ -91,13 +91,13 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
9191
T real_beta = ip == 0 ? beta : 0;
9292
const int ncolA_ip = colA_loc[ip];
9393
// get U_tmp
94-
95-
const int start_row = start_colA[ip];
96-
for (int i = 0; i < ncolB; ++i)
97-
{
98-
const T* U_part = U + start_row + (i + start) * ncolA_glo;
99-
syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip);
100-
}
94+
95+
const int start_row = start_colA[ip];
96+
for (int i = 0; i < ncolB; ++i)
97+
{
98+
const T* U_part = U + start_row + (i + start) * ncolA_glo;
99+
syncmem_dev_op()(U_tmp + i * ncolA_ip, U_part, ncolA_ip);
100+
}
101101

102102
if (ip == rank_col)
103103
{

source/module_io/read_set_globalv.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,10 @@ void ReadInput::set_globalv(const Input_para& inp, System_para& sys)
6161
Parallel_Common::bcast_bool(sys.double_grid);
6262
#endif
6363
/// set all_ks_run (false when ks_solver is not bpcg and bndpar > 1)
64-
if (GlobalV::MY_BNDGROUP == 0 || inp.ks_solver == "bpcg")
65-
{
66-
sys.ks_run = true;
67-
}
6864
if (inp.ks_solver != "bpcg" && inp.bndpar > 1)
6965
{
7066
sys.all_ks_run = false;
7167
}
72-
7368
}
7469

7570
/// @note Here para.inp has been synchronized of all ranks.

source/module_io/write_istate_info.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase:
4141
<< std::setw(25) << "Kpoint = " << ik_global
4242
<< std::setw(25) << "(" << kv.kvec_d[ik].x << " " << kv.kvec_d[ik].y
4343
<< " " << kv.kvec_d[ik].z << ")" << std::endl;
44-
for (int ib = 0; ib < PARAM.inp.nbands; ib++)
44+
for (int ib = 0; ib < PARAM.globalv.nbands_l; ib++)
4545
{
4646
ofsi2.precision(16);
4747
ofsi2 << std::setw(6) << ib + 1 << std::setw(25)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
INPUT_PARAMETERS
2+
#Parameters (General)
3+
suffix autotest
4+
pseudo_dir ../../PP_ORB
5+
pw_seed 1
6+
7+
gamma_only 0
8+
calculation scf
9+
symmetry 1
10+
out_level ie
11+
smearing_method gaussian
12+
smearing_sigma 0.02
13+
14+
#Parameters (3.PW)
15+
ecutwfc 40
16+
scf_thr 1e-7
17+
scf_nmax 20
18+
bndpar 2
19+
20+
#Parameters (LCAO)
21+
basis_type pw
22+
ks_solver bpcg
23+
device cpu
24+
chg_extrap second-order
25+
out_dm 0
26+
pw_diag_thr 0.00001
27+
28+
cal_force 1
29+
cal_stress 1
30+
31+
mixing_type broyden
32+
mixing_beta 0.4
33+
mixing_gg0 1.5

tests/integrate/102_PW_BPCG_BP/KPT

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
K_POINTS
2+
0
3+
Gamma
4+
2 2 2 0 0 0
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
This test is for:
2+
*GaAs-deformation
3+
*PW
4+
*bndpar 2
5+
*kpoints 2*2*2
6+
*sg15 pseudopotential
7+
*smearing_method gaussian
8+
*ks_solver bpcg
9+
*mixing_type broyden-kerker
10+
*mixing_beta 0.4

0 commit comments

Comments
 (0)