Skip to content

Commit 91b0281

Browse files
Refactor: remove the Psi Constructors using int* ngk_in (#5745)
* delete Psi(const int* ngk_in); * [pre-commit.ci lite] apply automatic fixes * format psi class * update hsolverpw * [pre-commit.ci lite] apply automatic fixes * add Constructor 8-2 * remove useless code * update operator.cpp call_act func * [pre-commit.ci lite] apply automatic fixes * change test code for psi Constructor * fix unit test bug * remove useless code --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent dc7147d commit 91b0281

File tree

16 files changed

+256
-159
lines changed

16 files changed

+256
-159
lines changed

source/module_hamilt_general/operator.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Operator<T, Device>::Operator(){}
1111
template<typename T, typename Device>
1212
Operator<T, Device>::~Operator()
1313
{
14-
if(this->hpsi != nullptr) delete this->hpsi;
14+
if(this->hpsi != nullptr) { delete this->hpsi;
15+
}
1516
Operator* last = this->next_op;
1617
Operator* last_sub = this->next_sub_op;
1718
while(last != nullptr || last_sub != nullptr)
@@ -61,8 +62,11 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6162
}
6263

6364
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
65+
6466
// a "psi" with the bands of needed range
65-
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis());
67+
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);
68+
69+
6670
switch (op->get_act_type())
6771
{
6872
case 2:
@@ -100,9 +104,11 @@ void Operator<T, Device>::init(const int ik_in)
100104
template<typename T, typename Device>
101105
void Operator<T, Device>::add(Operator* next)
102106
{
103-
if(next==nullptr) return;
107+
if(next==nullptr) { return;
108+
}
104109
next->is_first_node = false;
105-
if(next->next_op != nullptr) this->add(next->next_op);
110+
if(next->next_op != nullptr) { this->add(next->next_op);
111+
}
106112
Operator* last = this;
107113
//loop to end of the chain
108114
while(last->next_op != nullptr)

source/module_hsolver/hsolver_pw.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310
#endif
311311

312312
/// solve eigenvector and eigenvalue for H(k)
313-
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands());
313+
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * psi.get_nbands(), this->wfc_basis->nks);
314314

315315
if (skip_charge)
316316
{
@@ -361,19 +361,27 @@ template <typename T, typename Device>
361361
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
362362
psi::Psi<T, Device>& psi,
363363
std::vector<Real>& pre_condition,
364-
Real* eigenvalue)
364+
Real* eigenvalue,
365+
const int& nk_nums)
365366
{
366367
#ifdef __MPI
367368
const diag_comm_info comm_info = {POOL_WORLD, this->rank_in_pool, this->nproc_in_pool};
368369
#else
369370
const diag_comm_info comm_info = {this->rank_in_pool, this->nproc_in_pool};
370371
#endif
371372

373+
auto ngk_pointer = psi.get_ngk_pointer();
374+
375+
std::vector<int> ngk_vector(nk_nums, 0);
376+
for (int i = 0; i < nk_nums; i++)
377+
{
378+
ngk_vector[i] = ngk_pointer[i];
379+
}
380+
372381
if (this->method == "cg")
373382
{
374383
// wrap the subspace_func into a lambda function
375-
auto ngk_pointer = psi.get_ngk_pointer();
376-
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
384+
auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
377385
// psi_in should be a 2D tensor:
378386
// psi_in.shape() = [nbands, nbasis]
379387
const auto ndim = psi_in.shape().ndim();
@@ -383,12 +391,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
383391
1,
384392
psi_in.shape().dim_size(0),
385393
psi_in.shape().dim_size(1),
386-
ngk_pointer);
394+
ngk_vector);
387395
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
388396
1,
389397
psi_out.shape().dim_size(0),
390398
psi_out.shape().dim_size(1),
391-
ngk_pointer);
399+
ngk_vector);
392400
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
393401
ct::DeviceType::CpuDevice,
394402
ct::TensorShape({psi_in.shape().dim_size(0)}));
@@ -407,7 +415,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
407415
using ct_Device = typename ct::PsiToContainer<Device>::type;
408416

409417
// wrap the hpsi_func and spsi_func into a lambda function
410-
auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
418+
auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
411419
ModuleBase::timer::tick("DiagoCG_New", "hpsi_func");
412420
// psi_in should be a 2D tensor:
413421
// psi_in.shape() = [nbands, nbasis]
@@ -418,7 +426,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
418426
1,
419427
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
420428
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
421-
ngk_pointer);
429+
ngk_vector);
422430
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
423431
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
424432
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
@@ -477,13 +485,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
477485
{
478486
const int nband = psi.get_nbands();
479487
const int nbasis = psi.get_nbasis();
480-
auto ngk_pointer = psi.get_ngk_pointer();
481488
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
482-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
489+
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
483490
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
484491

485492
// Convert "pointer data stucture" to a psi::Psi object
486-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
493+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
487494

488495
psi::Range bands_range(true, 0, 0, nvec - 1);
489496

@@ -499,13 +506,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
499506
}
500507
else if (this->method == "dav_subspace")
501508
{
502-
auto ngk_pointer = psi.get_ngk_pointer();
503509
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
504-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
510+
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
505511
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
506512

507513
// Convert "pointer data stucture" to a psi::Psi object
508-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
514+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
509515

510516
psi::Range bands_range(true, 0, 0, nvec - 1);
511517

@@ -550,15 +556,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
550556
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
551557

552558
// Davidson matrix-blockvector functions
553-
554-
auto ngk_pointer = psi.get_ngk_pointer();
555559
/// wrap hpsi into lambda function, Matrix \times blockvector
556560
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
557-
auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
561+
auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
558562
ModuleBase::timer::tick("David", "hpsi_func");
559563

560564
// Convert pointer of psi_in to a psi::Psi object
561-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_pointer);
565+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, ngk_vector);
562566

563567
psi::Range bands_range(true, 0, 0, nvec - 1);
564568

source/module_hsolver/hsolver_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class HSolverPW
5656
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
5757
psi::Psi<T, Device>& psi,
5858
std::vector<Real>& pre_condition,
59-
Real* eigenvalue);
59+
Real* eigenvalue,
60+
const int& nk_nums);
6061

6162
// calculate the precondition array for diagonalization in PW base
6263
void update_precondition(std::vector<Real>& h_diag, const int ik, const int npw, const Real vl_of_0);

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
5454
this->update_precondition(precondition, ik, this->wfc_basis->npwk[ik], pes->pot->get_vl_of_0());
5555
/// solve eigenvector and eigenvalue for H(k)
5656
double* p_eigenvalues = &(pes->ekb(ik, 0));
57-
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues);
57+
this->hamiltSolvePsiK(pHamilt, psi, precondition, p_eigenvalues, nks);
5858
}
5959

6060
#ifdef __MPI

source/module_hsolver/test/diago_cg_float_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class DiagoCGPrepare
164164
auto psi_wrapper = psi::Psi<std::complex<float>>(
165165
psi_in.data<std::complex<float>>(), 1,
166166
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
167-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
167+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
168168
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
169169
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
170170
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<float>>());

source/module_hsolver/test/diago_cg_real_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class DiagoCGPrepare
167167
auto psi_wrapper = psi::Psi<double>(
168168
psi_in.data<double>(), 1,
169169
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
170-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
170+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
171171
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
172172
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
173173
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<double>());

source/module_hsolver/test/diago_cg_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class DiagoCGPrepare
158158
auto psi_wrapper = psi::Psi<std::complex<double>>(
159159
psi_in.data<std::complex<double>>(), 1,
160160
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
161-
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1));
161+
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
162162
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
163163
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
164164
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<double>>());

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class DiagoDavPrepare
113113
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
114114
const int ld_psi, const int nvec)
115115
{
116-
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, nullptr);
116+
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, true);
117117
psi::Range bands_range(true, 0, 0, nvec-1);
118118
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
119119
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class DiagoDavPrepare
112112
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
113113
const int ld_psi, const int nvec)
114114
{
115-
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, nullptr);
115+
auto psi_iter_wrapper = psi::Psi<double>(psi_in, 1, nvec, ld_psi, true);
116116
psi::Range bands_range(true, 0, 0, nvec-1);
117117
using hpsi_info = typename hamilt::Operator<double>::hpsi_info;
118118
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class DiagoDavPrepare
115115
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
116116
const int ld_psi, const int nvec)
117117
{
118-
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, nullptr);
118+
auto psi_iter_wrapper = psi::Psi<std::complex<double>>(psi_in, 1, nvec, ld_psi, true);
119119
psi::Range bands_range(true, 0, 0, nvec-1);
120120
using hpsi_info = typename hamilt::Operator<std::complex<double>>::hpsi_info;
121121
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

0 commit comments

Comments
 (0)