From 71a53a4e1c9e8e787179774e0d1d5fca1d4248a0 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 18 Dec 2024 12:59:21 +0000 Subject: [PATCH 1/3] refactor get_current_nbas func --- source/module_psi/psi.cpp | 43 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f129d3e422..b1e3d665e2 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -51,7 +51,16 @@ template Psi::Psi(const int* ngk_in) template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { this->k_first = k_first_in; - this->ngk = ngk_in; + + if (nk_in == 1) + { + this->ngk = nullptr; + } + else + { + this->ngk = ngk_in; + } + this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; @@ -68,7 +77,16 @@ template Psi::Psi(const int nk_in, cons template Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { this->k_first = k_first_in; - this->ngk = ngk_in; + + if (nk_in == 1) + { + this->ngk = nullptr; + } + else + { + this->ngk = ngk_in; + } + this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; @@ -368,7 +386,26 @@ template int Psi::get_current_b() const template int Psi::get_current_nbas() const { - return this->current_nbasis; + if (this->ngk == nullptr) + { + std::cout << this->nbasis << std::endl; + return this->nbasis; + } + else // this->ngk != nullptr + { + if (this->npol == 1) + { + return this->ngk[this->current_k]; + } + else if (this->npol == 2) + { + return this->nbasis; + } + else + { + assert(false && "In Psi Class, this->npol can only be 1 and 2, not other values."); + } + } } template const int& Psi::get_ngk(const int ik_in) const From 405a741be2c971947e8c935e58c02ac513132013 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:35:55 +0000 Subject: [PATCH 2/3] [pre-commit.ci lite] apply automatic fixes --- source/module_psi/psi.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index b1e3d665e2..575d8d4ed5 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -16,7 +16,7 @@ namespace psi Range::Range(const size_t range_in) { - k_first = 1; + k_first = true; index_1 = 0; range_1 = range_in; range_2 = range_in; @@ -38,7 +38,8 @@ template Psi::Psi() template Psi::~Psi() { - if (this->allocate_inside) delete_memory_op()(this->ctx, this->psi); + if (this->allocate_inside) { delete_memory_op()(this->ctx, this->psi); +} } template Psi::Psi(const int* ngk_in) @@ -296,12 +297,14 @@ template void Psi::fix_k(const int ik) { assert(ik >= 0); this->current_k = ik; - if (this->ngk != nullptr && this->npol != 2) + if (this->ngk != nullptr && this->npol != 2) { this->current_nbasis = this->ngk[ik]; - else + } else { this->current_nbasis = this->nbasis; +} - if (this->k_first)this->current_b = 0; + if (this->k_first) {this->current_b = 0; +} int base = this->current_b * this->nk * this->nbasis; if (ik >= this->nk) { @@ -320,7 +323,8 @@ template void Psi::fix_b(const int ib) assert(ib >= 0); this->current_b = ib; - if (!this->k_first)this->current_k = 0; + if (!this->k_first) {this->current_k = 0; +} int base = this->current_k * this->nbands * this->nbasis; if (ib >= this->nbands) { @@ -410,7 +414,8 @@ template int Psi::get_current_nbas() co template const int& Psi::get_ngk(const int ik_in) const { - if (!this->ngk) return this->nbasis; + if (!this->ngk) { return this->nbasis; +} return this->ngk[ik_in]; } From e8c1abd636552134f1266cfba5e425469e8cccda Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 18 Dec 2024 13:55:29 +0000 Subject: [PATCH 3/3] fix bug --- source/module_psi/psi.cpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 575d8d4ed5..bc5c16aed5 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -53,14 +53,8 @@ template Psi::Psi(const int nk_in, cons { this->k_first = k_first_in; - if (nk_in == 1) - { - this->ngk = nullptr; - } - else - { - this->ngk = ngk_in; - } + + this->ngk = ngk_in; this->current_b = 0; this->current_k = 0;