@@ -105,31 +105,44 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
105105 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
106106}
107107
108-
109108template <typename T, typename Device>
110109Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
111110{
112111 assert (nk_in <= psi_in.get_nk () && nk_in > 0 );
113112 assert (nband_in <= psi_in.get_nbands () && nband_in > 0 );
114113
115114 this ->k_first = psi_in.get_k_first ();
116- this ->resize (nk_in, nband_in, psi_in.get_nbasis ());
117- this ->ngk = psi_in.ngk ;
118115 this ->npol = psi_in.npol ;
119- if (nband_in <= psi_in.get_nbands ())
116+ this ->allocate_inside = true ;
117+
118+ this ->nk = nk_in;
119+ this ->nbands = nband_in;
120+ this ->nbasis = psi_in.get_nbasis ();
121+
122+ // This function will delete the psi array first(if psi exist), then malloc a new memory for it.
123+ resize_memory_op ()(this ->ctx ,
124+ this ->psi ,
125+ (static_cast <std::size_t >(this ->nk ) * static_cast <std::size_t >(this ->nbands )
126+ * static_cast <std::size_t >(this ->nbasis )),
127+ " no_record" );
128+ synchronize_memory_op ()(this ->ctx , psi_in.get_device (), this ->psi , psi_in.get_pointer (), this ->size ());
129+
130+ this ->current_k = 0 ;
131+ this ->current_b = 0 ;
132+ this ->current_nbasis = this ->nbasis ;
133+ this ->psi_current = this ->psi ;
134+ this ->psi_bias = 0 ;
135+
136+ if (this ->nk != psi_in.get_nk ())
120137 {
121- // copy from Psi from psi_in(current_k, 0, 0),
122- // if size of k is 1, current_k in new Psi is psi_in.current_k
123- if (nk_in == 1 )
124- {
125- // current_k for this Psi only keep the spin index same as the copied Psi
126- this ->current_k = psi_in.get_current_k ();
127- }
128- synchronize_memory_op ()(this ->ctx , psi_in.get_device (), this ->psi , psi_in.get_pointer (), this ->size ());
138+ this ->ngk = nullptr ;
139+ }
140+ else
141+ {
142+ this ->ngk = psi_in.ngk ;
129143 }
130144}
131145
132-
133146template <typename T, typename Device>
134147Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
135148{
0 commit comments