Skip to content

Commit cb47b72

Browse files
committed
update hsolverpw
1 parent 9bca39a commit cb47b72

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

source/module_hsolver/hsolver_pw.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
365365
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
366366
#endif
367367

368+
auto ngk_pointer = psi.get_ngk_pointer();
369+
368370
if (this->method == "cg")
369371
{
370372
// wrap the subspace_func into a lambda function
371-
auto ngk_pointer = psi.get_ngk_pointer();
372373
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
373374
// psi_in should be a 2D tensor:
374375
// psi_in.shape() = [nbands, nbasis]
@@ -379,12 +380,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
379380
1,
380381
psi_in.shape().dim_size(0),
381382
psi_in.shape().dim_size(1),
382-
ngk_pointer);
383+
nullptr);
383384
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
384385
1,
385386
psi_out.shape().dim_size(0),
386387
psi_out.shape().dim_size(1),
387-
ngk_pointer);
388+
nullptr);
388389
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
389390
ct::DeviceType::CpuDevice,
390391
ct::TensorShape({psi_in.shape().dim_size(0)}));
@@ -414,7 +415,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
414415
1,
415416
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
416417
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
417-
ngk_pointer);
418+
nullptr);
418419
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
419420
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
420421
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
@@ -473,13 +474,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
473474
{
474475
const int nband = psi.get_nbands();
475476
const int nbasis = psi.get_nbasis();
476-
auto ngk_pointer = psi.get_ngk_pointer();
477477
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478478
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479479
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
480480

481481
// Convert "pointer data stucture" to a psi::Psi object
482-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
482+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
483483

484484
psi::Range bands_range(true, 0, 0, nvec - 1);
485485

@@ -495,13 +495,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
495495
}
496496
else if (this->method == "dav_subspace")
497497
{
498-
auto ngk_pointer = psi.get_ngk_pointer();
499498
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
500499
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
501500
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
502501

503502
// Convert "pointer data stucture" to a psi::Psi object
504-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
503+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
505504

506505
psi::Range bands_range(true, 0, 0, nvec - 1);
507506

@@ -546,15 +545,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
546545
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
547546

548547
// Davidson matrix-blockvector functions
549-
550-
auto ngk_pointer = psi.get_ngk_pointer();
551548
/// wrap hpsi into lambda function, Matrix \times blockvector
552549
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
553550
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
554551
ModuleBase::timer::tick("David", "hpsi_func");
555552

556553
// Convert pointer of psi_in to a psi::Psi object
557-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
554+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, nullptr);
558555

559556
psi::Range bands_range(true, 0, 0, nvec - 1);
560557

0 commit comments

Comments
 (0)