Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions source/module_psi/psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace psi

Range::Range(const size_t range_in)
{
k_first = 1;
k_first = true;
index_1 = 0;
range_1 = range_in;
range_2 = range_in;
Expand All @@ -38,7 +38,8 @@ template <typename T, typename Device> Psi<T, Device>::Psi()

template <typename T, typename Device> Psi<T, Device>::~Psi()
{
if (this->allocate_inside) delete_memory_op()(this->ctx, this->psi);
if (this->allocate_inside) { delete_memory_op()(this->ctx, this->psi);
}
}

template <typename T, typename Device> Psi<T, Device>::Psi(const int* ngk_in)
Expand All @@ -51,7 +52,10 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const int* ngk_in)
template <typename T, typename Device> Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
{
this->k_first = k_first_in;


this->ngk = ngk_in;

this->current_b = 0;
this->current_k = 0;
this->npol = PARAM.globalv.npol;
Expand All @@ -68,7 +72,16 @@ template <typename T, typename Device> Psi<T, Device>::Psi(const int nk_in, cons
template <typename T, typename Device> Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in)
{
this->k_first = k_first_in;
this->ngk = ngk_in;

if (nk_in == 1)
{
this->ngk = nullptr;
}
else
{
this->ngk = ngk_in;
}

this->current_b = 0;
this->current_k = 0;
this->npol = PARAM.globalv.npol;
Expand Down Expand Up @@ -278,12 +291,14 @@ template <typename T, typename Device> void Psi<T, Device>::fix_k(const int ik)
{
assert(ik >= 0);
this->current_k = ik;
if (this->ngk != nullptr && this->npol != 2)
if (this->ngk != nullptr && this->npol != 2) {
this->current_nbasis = this->ngk[ik];
else
} else {
this->current_nbasis = this->nbasis;
}

if (this->k_first)this->current_b = 0;
if (this->k_first) {this->current_b = 0;
}
int base = this->current_b * this->nk * this->nbasis;
if (ik >= this->nk)
{
Expand All @@ -302,7 +317,8 @@ template <typename T, typename Device> void Psi<T, Device>::fix_b(const int ib)
assert(ib >= 0);
this->current_b = ib;

if (!this->k_first)this->current_k = 0;
if (!this->k_first) {this->current_k = 0;
}
int base = this->current_k * this->nbands * this->nbasis;
if (ib >= this->nbands)
{
Expand Down Expand Up @@ -368,12 +384,32 @@ template <typename T, typename Device> int Psi<T, Device>::get_current_b() const

template <typename T, typename Device> int Psi<T, Device>::get_current_nbas() const
{
return this->current_nbasis;
if (this->ngk == nullptr)
{
std::cout << this->nbasis << std::endl;
return this->nbasis;
}
else // this->ngk != nullptr
{
if (this->npol == 1)
{
return this->ngk[this->current_k];
}
else if (this->npol == 2)
{
return this->nbasis;
}
else
{
assert(false && "In Psi Class, this->npol can only be 1 and 2, not other values.");
}
}
}

template <typename T, typename Device> const int& Psi<T, Device>::get_ngk(const int ik_in) const
{
if (!this->ngk) return this->nbasis;
if (!this->ngk) { return this->nbasis;
}
return this->ngk[ik_in];
}

Expand Down
Loading