@@ -32,7 +32,6 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
3232template <typename T, typename Device>
3333Psi<T, Device>::Psi()
3434{
35- this ->npol = PARAM.globalv .npol ;
3635}
3736
3837template <typename T, typename Device>
@@ -53,7 +52,6 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5352 assert (nbs_in > 0 );
5453
5554 this ->k_first = k_first_in;
56- this ->npol = PARAM.globalv .npol ;
5755 this ->allocate_inside = true ;
5856
5957 this ->ngk = ngk_in; // modify later
@@ -91,7 +89,6 @@ Psi<T, Device>::Psi(const int nk_in,
9189 assert (nbs_in > 0 );
9290
9391 this ->k_first = k_first_in;
94- this ->npol = PARAM.globalv .npol ;
9592 this ->allocate_inside = true ;
9693
9794 this ->ngk = ngk_in.data (); // modify later
@@ -129,7 +126,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
129126 // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
130127
131128 this ->k_first = k_first_in;
132- this ->npol = PARAM.globalv .npol ;
133129 this ->allocate_inside = false ;
134130
135131 this ->ngk = nullptr ;
@@ -161,7 +157,6 @@ Psi<T, Device>::Psi(const int nk_in,
161157 assert (nk_in == 1 );
162158
163159 this ->k_first = k_first_in;
164- this ->npol = PARAM.globalv .npol ;
165160 this ->allocate_inside = true ;
166161
167162 this ->ngk = nullptr ;
@@ -191,7 +186,6 @@ template <typename T, typename Device>
191186Psi<T, Device>::Psi(const Psi& psi_in)
192187{
193188 this ->ngk = psi_in.ngk ;
194- this ->npol = PARAM.globalv .npol ;
195189 this ->nk = psi_in.get_nk ();
196190 this ->nbands = psi_in.get_nbands ();
197191 this ->nbasis = psi_in.get_nbasis ();
@@ -218,7 +212,6 @@ template <typename T_in, typename Device_in>
218212Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
219213{
220214 this ->ngk = psi_in.get_ngk_pointer ();
221- this ->npol = PARAM.globalv .npol ;
222215 this ->nk = psi_in.get_nk ();
223216 this ->nbands = psi_in.get_nbands ();
224217 this ->nbasis = psi_in.get_nbasis ();
@@ -331,7 +324,7 @@ const int& Psi<T, Device>::get_psi_bias() const
331324template <typename T, typename Device>
332325const int & Psi<T, Device>::get_current_ngk() const
333326{
334- if (this ->npol == 1 )
327+ if (this ->get_npol () == 1 )
335328 {
336329 return this ->current_nbasis ;
337330 }
@@ -341,6 +334,19 @@ const int& Psi<T, Device>::get_current_ngk() const
341334 }
342335}
343336
337+ template <typename T, typename Device>
338+ const int & Psi<T, Device>::get_npol() const
339+ {
340+ if (PARAM.inp .nspin == 4 )
341+ {
342+ return 2 ;
343+ }
344+ else
345+ {
346+ return 1 ;
347+ }
348+ }
349+
344350template <typename T, typename Device>
345351const int & Psi<T, Device>::get_nk() const
346352{
@@ -519,13 +525,13 @@ std::tuple<const T*, int> Psi<T, Device>::to_range(const Range& range) const
519525 else if (i1 < 0 ) // [r1, r2] is the range of index1 with length m
520526 {
521527 const T* p = &this ->psi [r1 * (k_first ? this ->nbands : this ->nk ) * this ->nbasis ];
522- int m = (r2 - r1 + 1 ) * this ->npol ;
528+ int m = (r2 - r1 + 1 ) * this ->get_npol () ;
523529 return std::tuple<const T*, int >(p, m);
524530 }
525531 else // [r1, r2] is the range of index2 with length m
526532 {
527533 const T* p = &this ->psi [(i1 * (k_first ? this ->nbands : this ->nk ) + r1) * this ->nbasis ];
528- int m = (r2 - r1 + 1 ) * this ->npol ;
534+ int m = (r2 - r1 + 1 ) * this ->get_npol () ;
529535 return std::tuple<const T*, int >(p, m);
530536 }
531537}
0 commit comments