@@ -171,6 +171,53 @@ Psi<T, Device>::Psi(const Psi& psi_in)
171171 this ->psi_current = this ->psi + psi_in.get_psi_bias ();
172172}
173173
174+ // Constructor 2-2
175+ template <typename T, typename Device>
176+ template <typename T_in, typename Device_in>
177+ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
178+ {
179+
180+ this ->ngk = psi_in.get_ngk_pointer ();
181+ this ->nk = psi_in.get_nk ();
182+ this ->nbands = psi_in.get_nbands ();
183+ this ->nbasis = psi_in.get_nbasis ();
184+ this ->current_k = psi_in.get_current_k ();
185+ this ->current_b = psi_in.get_current_b ();
186+ this ->k_first = psi_in.get_k_first ();
187+ // this function will copy psi_in.psi to this->psi no matter the device types of each other.
188+
189+ this ->resize (psi_in.get_nk (), psi_in.get_nbands (), psi_in.get_nbasis ());
190+
191+ // Specifically, if the Device_in type is CPU and the Device type is GPU.
192+ // Which means we need to initialize a GPU psi from a given CPU psi.
193+ // We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
194+ // Finally, synchronize the memory from CPU to GPU.
195+ // This could help to reduce the peak memory usage of device.
196+ if (std::is_same<Device, base_device::DEVICE_GPU>::value && std::is_same<Device_in, base_device::DEVICE_CPU>::value)
197+ {
198+ auto * arr = (T*)malloc (sizeof (T) * psi_in.size ());
199+ // cast the memory from T_in to T in CPU
200+ base_device::memory::cast_memory_op<T, T_in, Device_in, Device_in>()(arr,
201+ psi_in.get_pointer ()
202+ - psi_in.get_psi_bias (),
203+ psi_in.size ());
204+ // synchronize the memory from CPU to GPU
205+ base_device::memory::synchronize_memory_op<T, Device, Device_in>()(this ->psi ,
206+ arr,
207+ psi_in.size ());
208+ free (arr);
209+ }
210+ else
211+ {
212+ base_device::memory::cast_memory_op<T, T_in, Device, Device_in>()(this ->psi ,
213+ psi_in.get_pointer () - psi_in.get_psi_bias (),
214+ psi_in.size ());
215+ }
216+ this ->psi_bias = psi_in.get_psi_bias ();
217+ this ->current_nbasis = psi_in.get_current_nbas ();
218+ this ->psi_current = this ->psi + psi_in.get_psi_bias ();
219+ }
220+
174221template <typename T, typename Device>
175222void Psi<T, Device>::set_all_psi(const T* another_pointer, const std::size_t size_in)
176223{
@@ -497,6 +544,8 @@ template Psi<double, base_device::DEVICE_CPU>::Psi(const Psi<double, base_device
// Explicit instantiations of the converting constructor
// Psi<T, Device>::Psi(const Psi<T_in, Device_in>&).
// Each entry names one (target type, target device) <- (source type, source device) pair.

// double on GPU <- double on CPU
497544template Psi<double , base_device::DEVICE_GPU>::Psi(const Psi<double , base_device::DEVICE_CPU>&);
// std::complex<double> on CPU <- std::complex<double> on GPU
498545template Psi<std::complex <double >, base_device::DEVICE_CPU>::Psi(
499546 const Psi<std::complex <double >, base_device::DEVICE_GPU>&);
// std::complex<double> on CPU <- std::complex<float> on GPU (element type is converted as well)
547+ template Psi<std::complex <double >, base_device::DEVICE_CPU>::Psi(
548+ const Psi<std::complex <float >, base_device::DEVICE_GPU>&);
// std::complex<double> on GPU <- std::complex<double> on CPU
500549template Psi<std::complex <double >, base_device::DEVICE_GPU>::Psi(
501550 const Psi<std::complex <double >, base_device::DEVICE_CPU>&);
502551template Psi<std::complex <float >, base_device::DEVICE_GPU>::Psi(
0 commit comments