Skip to content

Commit cdc5aef

Browse files
committed
Fix: Fix the errors in building abacus with libtorch-gpu
1 parent 47bfd69 commit cdc5aef

File tree

4 files changed

+109
-50
lines changed

4 files changed

+109
-50
lines changed

source/source_io/write_mlkedf_descriptors.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,51 @@ void Write_MLKEDF_Descriptors::generateTrainData_KS(
6767
delete ptempRho;
6868
}
6969

70+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
71+
const std::string& out_dir,
72+
psi::Psi<std::complex<float>> *psi,
73+
elecstate::ElecState *pelec,
74+
ModulePW::PW_Basis_K *pw_psi,
75+
ModulePW::PW_Basis *pw_rho,
76+
UnitCell& ucell,
77+
const double* veff
78+
)
79+
{
80+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_double(*psi);
81+
82+
this->generateTrainData_KS(out_dir, &psi_double, pelec, pw_psi, pw_rho, ucell, veff);
83+
}
84+
85+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
86+
const std::string& out_dir,
87+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
88+
elecstate::ElecState *pelec,
89+
ModulePW::PW_Basis_K *pw_psi,
90+
ModulePW::PW_Basis *pw_rho,
91+
UnitCell& ucell,
92+
const double* veff
93+
)
94+
{
95+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu(*psi);
96+
97+
this->generateTrainData_KS(out_dir, &psi_cpu, pelec, pw_psi, pw_rho, ucell, veff);
98+
}
99+
100+
void Write_MLKEDF_Descriptors::generateTrainData_KS(
101+
const std::string& dir,
102+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
103+
elecstate::ElecState *pelec,
104+
ModulePW::PW_Basis_K *pw_psi,
105+
ModulePW::PW_Basis *pw_rho,
106+
UnitCell& ucell,
107+
const double *veff
108+
)
109+
{
110+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU> psi_cpu_double(*psi);
111+
112+
this->generateTrainData_KS(dir, &psi_cpu_double, pelec, pw_psi, pw_rho, ucell, veff);
113+
}
114+
70115
void Write_MLKEDF_Descriptors::generate_descriptor(
71116
const std::string& out_dir,
72117
const double * const *prho,

source/source_io/write_mlkedf_descriptors.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,27 @@ class Write_MLKEDF_Descriptors
4040
ModulePW::PW_Basis *pw_rho,
4141
UnitCell& ucell,
4242
const double *veff
43-
){} // a mock function
43+
);
44+
45+
void generateTrainData_KS(
46+
const std::string& dir,
47+
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi,
48+
elecstate::ElecState *pelec,
49+
ModulePW::PW_Basis_K *pw_psi,
50+
ModulePW::PW_Basis *pw_rho,
51+
UnitCell& ucell,
52+
const double *veff
53+
);
54+
void generateTrainData_KS(
55+
const std::string& dir,
56+
psi::Psi<std::complex<float>, base_device::DEVICE_GPU>* psi,
57+
elecstate::ElecState *pelec,
58+
ModulePW::PW_Basis_K *pw_psi,
59+
ModulePW::PW_Basis *pw_rho,
60+
UnitCell& ucell,
61+
const double *veff
62+
);
63+
4464
void generate_descriptor(
4565
const std::string& out_dir,
4666
const double * const *prho,

source/source_psi/psi.cpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -171,54 +171,6 @@ Psi<T, Device>::Psi(const Psi& psi_in)
171171
this->psi_current = this->psi + psi_in.get_psi_bias();
172172
}
173173

174-
175-
// Constructor 2-2:
176-
template <typename T, typename Device>
177-
template <typename T_in, typename Device_in>
178-
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
179-
{
180-
181-
this->ngk = psi_in.get_ngk_pointer();
182-
this->nk = psi_in.get_nk();
183-
this->nbands = psi_in.get_nbands();
184-
this->nbasis = psi_in.get_nbasis();
185-
this->current_k = psi_in.get_current_k();
186-
this->current_b = psi_in.get_current_b();
187-
this->k_first = psi_in.get_k_first();
188-
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
189-
190-
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
191-
192-
// Specifically, if the Device_in type is CPU and the Device type is GPU.
193-
// Which means we need to initialize a GPU psi from a given CPU psi.
194-
// We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
195-
// Finally, synchronize the memory from CPU to GPU.
196-
// This could help to reduce the peak memory usage of device.
197-
if (std::is_same<Device, base_device::DEVICE_GPU>::value && std::is_same<Device_in, base_device::DEVICE_CPU>::value)
198-
{
199-
auto* arr = (T*)malloc(sizeof(T) * psi_in.size());
200-
// cast the memory from T_in to T in CPU
201-
base_device::memory::cast_memory_op<T, T_in, Device_in, Device_in>()(arr,
202-
psi_in.get_pointer()
203-
- psi_in.get_psi_bias(),
204-
psi_in.size());
205-
// synchronize the memory from CPU to GPU
206-
base_device::memory::synchronize_memory_op<T, Device, Device_in>()(this->psi,
207-
arr,
208-
psi_in.size());
209-
free(arr);
210-
}
211-
else
212-
{
213-
base_device::memory::cast_memory_op<T, T_in, Device, Device_in>()(this->psi,
214-
psi_in.get_pointer() - psi_in.get_psi_bias(),
215-
psi_in.size());
216-
}
217-
this->psi_bias = psi_in.get_psi_bias();
218-
this->current_nbasis = psi_in.get_current_nbas();
219-
this->psi_current = this->psi + psi_in.get_psi_bias();
220-
}
221-
222174
template <typename T, typename Device>
223175
void Psi<T, Device>::set_all_psi(const T* another_pointer, const std::size_t size_in)
224176
{

source/source_psi/psi.h

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,49 @@ class Psi
4848
// Constructor 2-2: initialize a new psi from the given psi_in with a different class template
4949
// in this case, psi_in may have a different device type.
5050
template <typename T_in, typename Device_in = Device>
51-
Psi(const Psi<T_in, Device_in>& psi_in);
51+
Psi(const Psi<T_in, Device_in>& psi_in)
52+
{
53+
54+
this->ngk = psi_in.get_ngk_pointer();
55+
this->nk = psi_in.get_nk();
56+
this->nbands = psi_in.get_nbands();
57+
this->nbasis = psi_in.get_nbasis();
58+
this->current_k = psi_in.get_current_k();
59+
this->current_b = psi_in.get_current_b();
60+
this->k_first = psi_in.get_k_first();
61+
// this function will copy psi_in.psi to this->psi no matter the device types of each other.
62+
63+
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
64+
65+
// Specifically, if the Device_in type is CPU and the Device type is GPU.
66+
// Which means we need to initialize a GPU psi from a given CPU psi.
67+
// We first malloc a memory in CPU, then cast the memory from T_in to T in CPU.
68+
// Finally, synchronize the memory from CPU to GPU.
69+
// This could help to reduce the peak memory usage of device.
70+
if (std::is_same<Device, base_device::DEVICE_GPU>::value && std::is_same<Device_in, base_device::DEVICE_CPU>::value)
71+
{
72+
auto* arr = (T*)malloc(sizeof(T) * psi_in.size());
73+
// cast the memory from T_in to T in CPU
74+
base_device::memory::cast_memory_op<T, T_in, Device_in, Device_in>()(arr,
75+
psi_in.get_pointer()
76+
- psi_in.get_psi_bias(),
77+
psi_in.size());
78+
// synchronize the memory from CPU to GPU
79+
base_device::memory::synchronize_memory_op<T, Device, Device_in>()(this->psi,
80+
arr,
81+
psi_in.size());
82+
free(arr);
83+
}
84+
else
85+
{
86+
base_device::memory::cast_memory_op<T, T_in, Device, Device_in>()(this->psi,
87+
psi_in.get_pointer() - psi_in.get_psi_bias(),
88+
psi_in.size());
89+
}
90+
this->psi_bias = psi_in.get_psi_bias();
91+
this->current_nbasis = psi_in.get_current_nbas();
92+
this->psi_current = this->psi + psi_in.get_psi_bias();
93+
}
5294

5395
// Constructor 3-1: 2D Psi version
5496
// used in hsolver-pw function pointer and somewhere.

0 commit comments

Comments
 (0)