@@ -32,6 +32,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
3232template <typename T, typename Device>
3333Psi<T, Device>::Psi()
3434{
35+ this ->npol = this ->get_npol ();
3536}
3637
3738template <typename T, typename Device>
@@ -51,6 +52,8 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5152 assert (nbd_in >= 0 ); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU
5253 assert (nbs_in > 0 );
5354
55+ this ->npol = this ->get_npol ();
56+
5457 this ->k_first = k_first_in;
5558 this ->allocate_inside = true ;
5659
@@ -88,6 +91,9 @@ Psi<T, Device>::Psi(const int nk_in,
8891 assert (nbd_in > 0 );
8992 assert (nbs_in > 0 );
9093
94+
95+ this ->npol = this ->get_npol ();
96+
9197 this ->k_first = k_first_in;
9298 this ->allocate_inside = true ;
9399
@@ -125,6 +131,9 @@ Psi<T, Device>::Psi(T* psi_pointer,
125131 // Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
126132 // assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func
127133
134+
135+ this ->npol = this ->get_npol ();
136+
128137 this ->k_first = k_first_in;
129138 this ->allocate_inside = false ;
130139
@@ -156,6 +165,8 @@ Psi<T, Device>::Psi(const int nk_in,
156165 // Currently this function only supports nk_in == 1 when called within diagH_subspace_init.
157166 assert (nk_in == 1 );
158167
168+ this ->npol = this ->get_npol ();
169+
159170 this ->k_first = k_first_in;
160171 this ->allocate_inside = true ;
161172
@@ -185,6 +196,8 @@ Psi<T, Device>::Psi(const int nk_in,
185196template <typename T, typename Device>
186197Psi<T, Device>::Psi(const Psi& psi_in)
187198{
199+ this ->npol = this ->get_npol ();
200+
188201 this ->ngk = psi_in.ngk ;
189202 this ->nk = psi_in.get_nk ();
190203 this ->nbands = psi_in.get_nbands ();
@@ -211,6 +224,9 @@ template <typename T, typename Device>
211224template <typename T_in, typename Device_in>
212225Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
213226{
227+
228+ this ->npol = this ->get_npol ();
229+
214230 this ->ngk = psi_in.get_ngk_pointer ();
215231 this ->nk = psi_in.get_nk ();
216232 this ->nbands = psi_in.get_nbands ();
0 commit comments