@@ -107,6 +107,34 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
107107 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
108108}
109109
110+
111+ template <typename T, typename Device>
112+ Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
113+ {
114+ assert (nk_in <= psi_in.get_nk ());
115+ if (nband_in == 0 )
116+ {
117+ nband_in = psi_in.get_nbands ();
118+ }
119+ this ->k_first = psi_in.get_k_first ();
120+ this ->device = psi_in.device ;
121+ this ->resize (nk_in, nband_in, psi_in.get_nbasis ());
122+ this ->ngk = psi_in.ngk ;
123+ this ->npol = psi_in.npol ;
124+ if (nband_in <= psi_in.get_nbands ())
125+ {
126+ // copy from Psi from psi_in(current_k, 0, 0),
127+ // if size of k is 1, current_k in new Psi is psi_in.current_k
128+ if (nk_in == 1 )
129+ {
130+ // current_k for this Psi only keep the spin index same as the copied Psi
131+ this ->current_k = psi_in.get_current_k ();
132+ }
133+ synchronize_memory_op ()(this ->ctx , psi_in.get_device (), this ->psi , psi_in.get_pointer (), this ->size ());
134+ }
135+ }
136+
137+
110138template <typename T, typename Device>
111139Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
112140{
@@ -208,11 +236,11 @@ template <typename T, typename Device>
208236void Psi<T, Device>::resize(const int nks_in, const int nbands_in, const int nbasis_in)
209237{
210238 assert (nks_in > 0 && nbands_in >= 0 && nbasis_in > 0 );
211-
239+
212240 // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
213241 resize_memory_op ()(this ->ctx , this ->psi , nks_in * static_cast <std::size_t >(nbands_in) * nbasis_in, " no_record" );
214242
215- this ->zero_out ();
243+ // this->zero_out();
216244
217245 this ->nk = nks_in;
218246 this ->nbands = nbands_in;
0 commit comments