@@ -32,7 +32,6 @@ template <typename T, typename Device>
3232Psi<T, Device>::Psi()
3333{
3434 this ->npol = PARAM.globalv .npol ;
35- this ->device = base_device::get_device_type<Device>(this ->ctx );
3635}
3736
3837template <typename T, typename Device>
@@ -52,8 +51,9 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
5251 this ->current_b = 0 ;
5352 this ->current_k = 0 ;
5453 this ->npol = PARAM.globalv .npol ;
55- this -> device = base_device::get_device_type<Device>( this -> ctx );
54+
5655 this ->resize (nk_in, nbd_in, nbs_in);
56+
5757 // Currently only GPU's implementation is supported for device recording!
5858 base_device::information::print_device_info<Device>(this ->ctx , GlobalV::ofs_device);
5959 base_device::information::record_device_memory<Device>(this ->ctx ,
@@ -76,7 +76,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
7676 this ->current_b = 0 ;
7777 this ->current_k = 0 ;
7878 this ->npol = PARAM.globalv .npol ;
79- this ->device = base_device::get_device_type<Device>(this ->ctx );
8079 this ->nk = nk_in;
8180 this ->nbands = nbd_in;
8281 this ->nbasis = nbs_in;
@@ -96,7 +95,6 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
9695 this ->current_b = 0 ;
9796 this ->current_k = 0 ;
9897 this ->npol = PARAM.globalv .npol ;
99- this ->device = base_device::get_device_type<Device>(this ->ctx );
10098 this ->nk = nk_in;
10199 this ->nbands = nbd_in;
102100 this ->nbasis = nbs_in;
@@ -111,13 +109,10 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
111109template <typename T, typename Device>
112110Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
113111{
114- assert (nk_in <= psi_in.get_nk ());
115- if (nband_in == 0 )
116- {
117- nband_in = psi_in.get_nbands ();
118- }
112+ assert (nk_in <= psi_in.get_nk () && nk_in > 0 );
113+ assert (nband_in <= psi_in.get_nbands () && nband_in > 0 );
114+
119115 this ->k_first = psi_in.get_k_first ();
120- this ->device = psi_in.device ;
121116 this ->resize (nk_in, nband_in, psi_in.get_nbasis ());
122117 this ->ngk = psi_in.ngk ;
123118 this ->npol = psi_in.npol ;
@@ -139,8 +134,6 @@ template <typename T, typename Device>
139134Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
140135{
141136 this ->k_first = psi_in.get_k_first ();
142- this ->device = base_device::get_device_type<Device>(this ->ctx );
143- assert (this ->device == psi_in.device );
144137 assert (nk_in <= psi_in.get_nk ());
145138 if (nband_in == 0 )
146139 {
@@ -168,7 +161,7 @@ Psi<T, Device>::Psi(const Psi& psi_in)
168161 this ->current_b = psi_in.get_current_b ();
169162 this ->k_first = psi_in.get_k_first ();
170163 // this function will copy psi_in.psi to this->psi no matter the device types of each other.
171- this -> device = base_device::get_device_type<Device>( this -> ctx );
164+
172165 this ->resize (psi_in.get_nk (), psi_in.get_nbands (), psi_in.get_nbasis ());
173166 base_device::memory::synchronize_memory_op<T, Device, Device>()(this ->ctx ,
174167 psi_in.get_device (),
@@ -193,7 +186,7 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
193186 this ->current_b = psi_in.get_current_b ();
194187 this ->k_first = psi_in.get_k_first ();
195188 // this function will copy psi_in.psi to this->psi no matter the device types of each other.
196- this -> device = base_device::get_device_type<Device>( this -> ctx );
189+
197190 this ->resize (psi_in.get_nk (), psi_in.get_nbands (), psi_in.get_nbasis ());
198191
199192 // Specifically, if the Device_in type is CPU and the Device type is GPU.
0 commit comments