Skip to content

Commit 71a53a4

Browse files
committed
refactor get_current_nbas func
1 parent f0ff82b commit 71a53a4

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

source/module_psi/psi.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,16 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const int* ngk_in)
5151
template <typename T, typename Device> Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
5252
{
5353
this->k_first = k_first_in;
54-
this->ngk = ngk_in;
54+
55+
if (nk_in == 1)
56+
{
57+
this->ngk = nullptr;
58+
}
59+
else
60+
{
61+
this->ngk = ngk_in;
62+
}
63+
5564
this->current_b = 0;
5665
this->current_k = 0;
5766
this->npol = PARAM.globalv.npol;
@@ -68,7 +77,16 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const int nk_in, cons
6877
template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
6978
{
7079
this->k_first = k_first_in;
71-
this->ngk = ngk_in;
80+
81+
if (nk_in == 1)
82+
{
83+
this->ngk = nullptr;
84+
}
85+
else
86+
{
87+
this->ngk = ngk_in;
88+
}
89+
7290
this->current_b = 0;
7391
this->current_k = 0;
7492
this->npol = PARAM.globalv.npol;
@@ -368,7 +386,26 @@ template <typename T, typename Device> int Psi<T, Device>::get_current_b() const
368386

369387
template <typename T, typename Device> int Psi<T, Device>::get_current_nbas() const
370388
{
371-
return this->current_nbasis;
389+
if (this->ngk == nullptr)
390+
{
391+
std::cout << this->nbasis << std::endl;
392+
return this->nbasis;
393+
}
394+
else // this->ngk != nullptr
395+
{
396+
if (this->npol == 1)
397+
{
398+
return this->ngk[this->current_k];
399+
}
400+
else if (this->npol == 2)
401+
{
402+
return this->nbasis;
403+
}
404+
else
405+
{
406+
assert(false && "In Psi Class, this->npol can only be 1 and 2, not other values.");
407+
}
408+
}
372409
}
373410

374411
template <typename T, typename Device> const int& Psi<T, Device>::get_ngk(const int ik_in) const

0 commit comments

Comments
 (0)