Skip to content

Commit 9870c24

Browse files
committed
Fix: Refactor Psi constructor to separate declaration and implementation for better readability and maintainability
1 parent bb78624 commit 9870c24

File tree

2 files changed

+50
-43
lines changed

2 files changed

+50
-43
lines changed

source/source_psi/psi.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,53 @@ Psi<T, Device>::Psi(const Psi& psi_in)
171171
this->psi_current = this->psi + psi_in.get_psi_bias();
172172
}
173173

174+
// Constructor 2-2
175+
template <typename T, typename Device>
176+
template <typename T_in, typename Device_in>
177+
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
178+
{
179+
180+
this->ngk = psi_in.get_ngk_pointer();
181+
this->nk = psi_in.get_nk();
182+
this->nbands = psi_in.get_nbands();
183+
this->nbasis = psi_in.get_nbasis();
184+
this->current_k = psi_in.get_current_k();
185+
this->current_b = psi_in.get_current_b();
186+
this->k_first = psi_in.get_k_first();
187+
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
188+
189+
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
190+
191+
// Specifically, if the Device_in type is CPU and the Device type is GPU.
192+
// Which means we need to initialize a GPU psi from a given CPU psi.
193+
// We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
194+
// Finally, synchronize the memory from CPU to GPU.
195+
// This could help to reduce the peak memory usage of device.
196+
if (std::is_same<Device, base_device::DEVICE_GPU>::value && std::is_same<Device_in, base_device::DEVICE_CPU>::value)
197+
{
198+
auto* arr = (T*)malloc(sizeof(T) * psi_in.size());
199+
// cast the memory from T_in to T in CPU
200+
base_device::memory::cast_memory_op<T, T_in, Device_in, Device_in>()(arr,
201+
psi_in.get_pointer()
202+
- psi_in.get_psi_bias(),
203+
psi_in.size());
204+
// synchronize the memory from CPU to GPU
205+
base_device::memory::synchronize_memory_op<T, Device, Device_in>()(this->psi,
206+
arr,
207+
psi_in.size());
208+
free(arr);
209+
}
210+
else
211+
{
212+
base_device::memory::cast_memory_op<T, T_in, Device, Device_in>()(this->psi,
213+
psi_in.get_pointer() - psi_in.get_psi_bias(),
214+
psi_in.size());
215+
}
216+
this->psi_bias = psi_in.get_psi_bias();
217+
this->current_nbasis = psi_in.get_current_nbas();
218+
this->psi_current = this->psi + psi_in.get_psi_bias();
219+
}
220+
174221
template <typename T, typename Device>
175222
void Psi<T, Device>::set_all_psi(const T* another_pointer, const std::size_t size_in)
176223
{
@@ -497,6 +544,8 @@ template Psi<double, base_device::DEVICE_CPU>::Psi(const Psi<double, base_device
497544
template Psi<double, base_device::DEVICE_GPU>::Psi(const Psi<double, base_device::DEVICE_CPU>&);
498545
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
499546
const Psi<std::complex<double>, base_device::DEVICE_GPU>&);
547+
template Psi<std::complex<double>, base_device::DEVICE_CPU>::Psi(
548+
const Psi<std::complex<float>, base_device::DEVICE_GPU>&);
500549
template Psi<std::complex<double>, base_device::DEVICE_GPU>::Psi(
501550
const Psi<std::complex<double>, base_device::DEVICE_CPU>&);
502551
template Psi<std::complex<float>, base_device::DEVICE_GPU>::Psi(

source/source_psi/psi.h

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,49 +48,7 @@ class Psi
4848
// Constructor 2-2: initialize a new psi from the given psi_in with a different class template
4949
// in this case, psi_in may have a different device type.
5050
template <typename T_in, typename Device_in = Device>
51-
Psi(const Psi<T_in, Device_in>& psi_in)
52-
{
53-
54-
this->ngk = psi_in.get_ngk_pointer();
55-
this->nk = psi_in.get_nk();
56-
this->nbands = psi_in.get_nbands();
57-
this->nbasis = psi_in.get_nbasis();
58-
this->current_k = psi_in.get_current_k();
59-
this->current_b = psi_in.get_current_b();
60-
this->k_first = psi_in.get_k_first();
61-
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
62-
63-
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
64-
65-
// Specifically, if the Device_in type is CPU and the Device type is GPU.
66-
// Which means we need to initialize a GPU psi from a given CPU psi.
67-
// We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
68-
// Finally, synchronize the memory from CPU to GPU.
69-
// This could help to reduce the peak memory usage of device.
70-
if (std::is_same<Device, base_device::DEVICE_GPU>::value && std::is_same<Device_in, base_device::DEVICE_CPU>::value)
71-
{
72-
auto* arr = (T*)malloc(sizeof(T) * psi_in.size());
73-
// cast the memory from T_in to T in CPU
74-
base_device::memory::cast_memory_op<T, T_in, Device_in, Device_in>()(arr,
75-
psi_in.get_pointer()
76-
- psi_in.get_psi_bias(),
77-
psi_in.size());
78-
// synchronize the memory from CPU to GPU
79-
base_device::memory::synchronize_memory_op<T, Device, Device_in>()(this->psi,
80-
arr,
81-
psi_in.size());
82-
free(arr);
83-
}
84-
else
85-
{
86-
base_device::memory::cast_memory_op<T, T_in, Device, Device_in>()(this->psi,
87-
psi_in.get_pointer() - psi_in.get_psi_bias(),
88-
psi_in.size());
89-
}
90-
this->psi_bias = psi_in.get_psi_bias();
91-
this->current_nbasis = psi_in.get_current_nbas();
92-
this->psi_current = this->psi + psi_in.get_psi_bias();
93-
}
51+
Psi(const Psi<T_in, Device_in>& psi_in);
9452

9553
// Constructor 3-1: 2D Psi version
9654
// used in hsolver-pw function pointer and somewhere.

0 commit comments

Comments
 (0)