@@ -53,9 +53,15 @@ class Psi
5353{
5454public:
5555 // Constructor 1: basic
56- Psi (void ){};
56+ Psi (void ){
57+ this ->npol = GlobalV::NPOL;
58+ };
5759 // Constructor 2: specify ngk only, should call resize() later
58- Psi (const int * ngk_in){this ->ngk = ngk_in;}
60+ Psi (const int * ngk_in)
61+ {
62+ this ->ngk = ngk_in;
63+ this ->npol = GlobalV::NPOL;
64+ }
5965 // Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later
6066 Psi (int nk_in, int nbd_in, int nbs_in, const int * ngk_in=nullptr )
6167 {
@@ -76,16 +82,20 @@ class Psi
7682 this ->resize (nk_in, nband_in, psi_in.get_nbasis ());
7783 this ->ngk = psi_in.ngk ;
7884 this ->npol = psi_in.npol ;
79- // copy from Psi from psi_in(current_k, 0, 0),
80- // if size of k is 1, current_k in new Psi is psi_in.current_k
81- const T* tmp = psi_in.get_pointer ();
82- if (nk_in==1 ) for (size_t index=0 ; index<this ->size ();++index)
85+
86+ if (nband_in <= psi_in.get_nbands ())
8387 {
84- psi[index] = tmp[index];
85- // current_k for this Psi only keep the spin index same as the copied Psi
86- this ->current_k = psi_in.get_current_k ();
87- }
88- else for (size_t index=0 ; index<this ->size ();++index) psi[index] = tmp[index];
88+ // copy from Psi from psi_in(current_k, 0, 0),
89+ // if size of k is 1, current_k in new Psi is psi_in.current_k
90+ const T* tmp = psi_in.get_pointer ();
91+ if (nk_in==1 ) for (size_t index=0 ; index<this ->size ();++index)
92+ {
93+ psi[index] = tmp[index];
94+ // current_k for this Psi only keep the spin index same as the copied Psi
95+ this ->current_k = psi_in.get_current_k ();
96+ }
97+ else for (size_t index=0 ; index<this ->size ();++index) psi[index] = tmp[index];
98+ }
8999 }
90100 // initialize the wavefunction coefficient
91101 // only resize and construct function now is used
@@ -235,12 +245,12 @@ class Psi
235245 else
236246 {
237247 const T* p = &this ->psi [(range.index_1 * this ->nbands + range.range_1 ) * this ->nbasis ];
238- int m = range.range_2 - range.range_1 + 1 ;
248+ int m = ( range.range_2 - range.range_1 + 1 )* this -> npol ;
239249 return std::tuple<const T*, int >(p, m);
240250 }
241251 }
242252
243- bool npol = 1 ;
253+ int npol = 1 ;
244254
245255 private:
246256 std::vector<T> psi;
0 commit comments