diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f129d3e422..bc5c16aed5 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -16,7 +16,7 @@ namespace psi Range::Range(const size_t range_in) { - k_first = 1; + k_first = true; index_1 = 0; range_1 = range_in; range_2 = range_in; @@ -38,7 +38,8 @@ template Psi::Psi() template Psi::~Psi() { - if (this->allocate_inside) delete_memory_op()(this->ctx, this->psi); + if (this->allocate_inside) { delete_memory_op()(this->ctx, this->psi); +} } template Psi::Psi(const int* ngk_in) @@ -51,7 +52,10 @@ template Psi::Psi(const int* ngk_in) template Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { this->k_first = k_first_in; + + this->ngk = ngk_in; + this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; @@ -68,7 +72,16 @@ template Psi::Psi(const int nk_in, cons template Psi::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) { this->k_first = k_first_in; - this->ngk = ngk_in; + + if (nk_in == 1) + { + this->ngk = nullptr; + } + else + { + this->ngk = ngk_in; + } + this->current_b = 0; this->current_k = 0; this->npol = PARAM.globalv.npol; @@ -278,12 +291,14 @@ template void Psi::fix_k(const int ik) { assert(ik >= 0); this->current_k = ik; - if (this->ngk != nullptr && this->npol != 2) + if (this->ngk != nullptr && this->npol != 2) { this->current_nbasis = this->ngk[ik]; - else + } else { this->current_nbasis = this->nbasis; +} - if (this->k_first)this->current_b = 0; + if (this->k_first) {this->current_b = 0; +} int base = this->current_b * this->nk * this->nbasis; if (ik >= this->nk) { @@ -302,7 +317,8 @@ template void Psi::fix_b(const int ib) assert(ib >= 0); this->current_b = ib; - if (!this->k_first)this->current_k = 0; + if (!this->k_first) {this->current_k = 0; +} int base = this->current_k * this->nbands * this->nbasis; if (ib >= this->nbands) { @@ -368,12 +384,32 @@ template int Psi::get_current_b() const template int Psi::get_current_nbas() const { - return this->current_nbasis; + if (this->ngk == nullptr) + { + std::cout << this->nbasis << std::endl; + return this->nbasis; + } + else // this->ngk != nullptr + { + if (this->npol == 1) + { + return this->ngk[this->current_k]; + } + else if (this->npol == 2) + { + return this->nbasis; + } + else + { + assert(false && "In Psi Class, this->npol can only be 1 and 2, not other values."); + } + } } template const int& Psi::get_ngk(const int ik_in) const { - if (!this->ngk) return this->nbasis; + if (!this->ngk) { return this->nbasis; +} return this->ngk[ik_in]; }