1212#include " module_base/global_function.h"
1313#include " module_base/kernels/math_kernel_op.h"
1414#include " para_linear_transform.h"
15+ #include " module_parameter/parameter.h"
1516
1617namespace hsolver {
1718
@@ -44,9 +45,9 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nband_l, const i
4445
4546 // All column major tensors
4647
47- this ->beta = std::move (ct::Tensor (r_type, device_type, {this ->n_band }));
48+ this ->beta = std::move (ct::Tensor (r_type, device_type, {this ->n_band_l }));
4849 this ->eigen = std::move (ct::Tensor (r_type, device_type, {this ->n_band }));
49- this ->err_st = std::move (ct::Tensor (r_type, device_type, {this ->n_band }));
50+ this ->err_st = std::move (ct::Tensor (r_type, device_type, {this ->n_band_l }));
5051
5152 this ->hsub = std::move (ct::Tensor (t_type, device_type, {this ->n_band , this ->n_band }));
5253
@@ -175,7 +176,7 @@ void DiagoBPCG<T, Device>::rotate_wf(
175176{
176177 // gemm: workspace_in(n_basis x n_band) = psi_out(n_basis x n_band) * hsub_in(n_band x n_band)
177178 this ->plintrans .act (1.0 , psi_out.data <T>(), hsub_in.data <T>(), 0.0 , workspace_in.data <T>());
178- syncmem_complex_op ()(psi_out.template data <T>(), workspace_in.template data <T>(), this ->n_band * this ->n_basis );
179+ syncmem_complex_op ()(psi_out.template data <T>(), workspace_in.template data <T>(), this ->n_band_l * this ->n_basis );
179180
180181 return ;
181182}
@@ -187,7 +188,7 @@ void DiagoBPCG<T, Device>::calc_hpsi_with_block(
187188 ct::Tensor& hpsi_out)
188189{
189190 // calculate all-band hpsi
190- hpsi_func (psi_in, hpsi_out.data <T>(), this ->n_basis , this ->n_band );
191+ hpsi_func (psi_in, hpsi_out.data <T>(), this ->n_basis , this ->n_band_l );
191192}
192193
193194template <typename T, typename Device>
@@ -256,17 +257,17 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
256257{
257258 const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
258259 // Get the pointer of the input psi
259- this ->psi = std::move (ct::TensorMap (psi_in /* psi_in.get_pointer()*/ , t_type, device_type, {this ->n_band , this ->n_basis }));
260+ this ->psi = std::move (ct::TensorMap (psi_in /* psi_in.get_pointer()*/ , t_type, device_type, {this ->n_band_l , this ->n_basis }));
260261
261262 // Update the precondition array
262263 this ->calc_prec ();
263264
264265 // Improving the initial guess of the wave function psi through a subspace diagonalization.
265266 this ->calc_hsub_with_block (hpsi_func, psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
266267
267- setmem_complex_op ()(this ->grad_old .template data <T>(), 0 , this ->n_basis * this ->n_band );
268+ setmem_complex_op ()(this ->grad_old .template data <T>(), 0 , this ->n_basis * this ->n_band_l );
268269
269- setmem_var_op ()(this ->beta .template data <Real>(), std::numeric_limits<Real>::infinity (), this ->n_band );
270+ setmem_var_op ()(this ->beta .template data <Real>(), std::numeric_limits<Real>::infinity (), this ->n_band_l );
270271
271272 int ntry = 0 ;
272273 int max_iter = current_scf_iter > 1 ?
@@ -290,7 +291,7 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
290291 this ->orth_projection (this ->psi , this ->hsub , this ->grad );
291292
292293 // this->grad_old = this->grad;
293- syncmem_complex_op ()(this ->grad_old .template data <T>(), this ->grad .template data <T>(), n_basis * n_band );
294+ syncmem_complex_op ()(this ->grad_old .template data <T>(), this ->grad .template data <T>(), n_basis * n_band_l );
294295
295296 // Calculate H|grad> matrix
296297 this ->calc_hpsi_with_block (hpsi_func, this ->grad .template data <T>(), /* this->grad_wrapper[0],*/ this ->hgrad );
@@ -311,7 +312,14 @@ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
311312
312313 this ->calc_hsub_with_block_exit (this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
313314
314- syncmem_var_d2h_op ()(eigenvalue_in, this ->eigen .template data <Real>(), this ->n_band );
315+ int start_nband = 0 ;
316+ #ifdef __MPI
317+ if (PARAM.inp .bndpar > 1 )
318+ {
319+ start_nband = this ->plintrans .start_colB [GlobalV::MY_BNDGROUP];
320+ }
321+ #endif
322+ syncmem_var_d2h_op ()(eigenvalue_in, this ->eigen .template data <Real>() + start_nband, this ->n_band_l );
315323
316324 return ;
317325}
0 commit comments