Skip to content

Commit 47ff9b6

Browse files
committed
remove npol value in psi
1 parent 5c04f9a commit 47ff9b6

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

source/module_psi/psi.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
3232
template <typename T, typename Device>
3333
Psi<T, Device>::Psi()
3434
{
35-
this->npol = PARAM.globalv.npol;
3635
}
3736

3837
template <typename T, typename Device>
@@ -53,7 +52,6 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5352
assert(nbs_in > 0);
5453

5554
this->k_first = k_first_in;
56-
this->npol = PARAM.globalv.npol;
5755
this->allocate_inside = true;
5856

5957
this->ngk = ngk_in; // modify later
@@ -91,7 +89,6 @@ Psi<T, Device>::Psi(const int nk_in,
9189
assert(nbs_in > 0);
9290

9391
this->k_first = k_first_in;
94-
this->npol = PARAM.globalv.npol;
9592
this->allocate_inside = true;
9693

9794
this->ngk = ngk_in.data(); // modify later
@@ -129,7 +126,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
129126
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
130127

131128
this->k_first = k_first_in;
132-
this->npol = PARAM.globalv.npol;
133129
this->allocate_inside = false;
134130

135131
this->ngk = nullptr;
@@ -161,7 +157,6 @@ Psi<T, Device>::Psi(const int nk_in,
161157
assert(nk_in == 1);
162158

163159
this->k_first = k_first_in;
164-
this->npol = PARAM.globalv.npol;
165160
this->allocate_inside = true;
166161

167162
this->ngk = nullptr;
@@ -191,7 +186,6 @@ template <typename T, typename Device>
191186
Psi<T, Device>::Psi(const Psi& psi_in)
192187
{
193188
this->ngk = psi_in.ngk;
194-
this->npol = PARAM.globalv.npol;
195189
this->nk = psi_in.get_nk();
196190
this->nbands = psi_in.get_nbands();
197191
this->nbasis = psi_in.get_nbasis();
@@ -218,7 +212,6 @@ template <typename T_in, typename Device_in>
218212
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
219213
{
220214
this->ngk = psi_in.get_ngk_pointer();
221-
this->npol = PARAM.globalv.npol;
222215
this->nk = psi_in.get_nk();
223216
this->nbands = psi_in.get_nbands();
224217
this->nbasis = psi_in.get_nbasis();
@@ -331,7 +324,7 @@ const int& Psi<T, Device>::get_psi_bias() const
331324
template <typename T, typename Device>
332325
const int& Psi<T, Device>::get_current_ngk() const
333326
{
334-
if (this->npol == 1)
327+
if (this->get_npol() == 1)
335328
{
336329
return this->current_nbasis;
337330
}
@@ -341,6 +334,19 @@ const int& Psi<T, Device>::get_current_ngk() const
341334
}
342335
}
343336

337+
template <typename T, typename Device>
338+
const int& Psi<T, Device>::get_npol() const
339+
{
340+
if (PARAM.inp.nspin == 4)
341+
{
342+
return 2;
343+
}
344+
else
345+
{
346+
return 1;
347+
}
348+
}
349+
344350
template <typename T, typename Device>
345351
const int& Psi<T, Device>::get_nk() const
346352
{
@@ -519,13 +525,13 @@ std::tuple<const T*, int> Psi<T, Device>::to_range(const Range& range) const
519525
else if (i1 < 0) // [r1, r2] is the range of index1 with length m
520526
{
521527
const T* p = &this->psi[r1 * (k_first ? this->nbands : this->nk) * this->nbasis];
522-
int m = (r2 - r1 + 1) * this->npol;
528+
int m = (r2 - r1 + 1) * this->get_npol();
523529
return std::tuple<const T*, int>(p, m);
524530
}
525531
else // [r1, r2] is the range of index2 with length m
526532
{
527533
const T* p = &this->psi[(i1 * (k_first ? this->nbands : this->nk) + r1) * this->nbasis];
528-
int m = (r2 - r1 + 1) * this->npol;
534+
int m = (r2 - r1 + 1) * this->get_npol();
529535
return std::tuple<const T*, int>(p, m);
530536
}
531537
}

source/module_psi/psi.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,11 @@ class Psi
138138
std::tuple<const T*, int> to_range(const Range& range) const;
139139

140140

141-
const int& get_npol() const { return this->npol;}
141+
const int& get_npol() const;
142142

143143
private:
144144
T* psi = nullptr; // avoid using C++ STL
145145

146-
int npol = 1;
147-
148146
Device* ctx = {}; // an context identifier for obtaining the device variable
149147

150148
// dimensions

0 commit comments

Comments
 (0)