Skip to content

Commit 14ae255

Browse files
authored
Feature: Output real space wavefunction and partial charge density when device=gpu (#6391)
* Fix GPU output of out_pchg and out_wfc_norm, out_wfc_re_im * GPU integrate test is functional again
1 parent b994d95 commit 14ae255

File tree

12 files changed

+179
-104
lines changed

12 files changed

+179
-104
lines changed

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,27 @@
1111
#include "source_estate/module_charge/symmetry_rho.h"
1212
#include "source_hamilt/module_ewald/H_Ewald_pw.h"
1313
#include "source_hamilt/module_vdw/vdw.h"
14-
#include "source_lcao/module_deltaspin/spin_constrain.h"
15-
#include "source_lcao/module_dftu/dftu.h"
16-
#include "source_pw/module_pwdft/elecond.h"
17-
#include "source_pw/module_pwdft/forces.h"
18-
#include "source_pw/module_pwdft/hamilt_pw.h"
19-
#include "source_pw/module_pwdft/onsite_projector.h"
20-
#include "source_pw/module_pwdft/stress_pw.h"
2114
#include "source_hsolver/diago_iter_assist.h"
2215
#include "source_hsolver/hsolver_pw.h"
2316
#include "source_hsolver/kernels/dngvd_op.h"
2417
#include "source_io/berryphase.h"
2518
#include "source_io/cal_ldos.h"
2619
#include "source_io/get_pchg_pw.h"
2720
#include "source_io/get_wf_pw.h"
21+
#include "source_io/module_parameter/parameter.h"
2822
#include "source_io/numerical_basis.h"
2923
#include "source_io/numerical_descriptor.h"
3024
#include "source_io/to_wannier90_pw.h"
3125
#include "source_io/winput.h"
3226
#include "source_io/write_dos_pw.h"
3327
#include "source_io/write_wfc_pw.h"
34-
#include "source_io/module_parameter/parameter.h"
28+
#include "source_lcao/module_deltaspin/spin_constrain.h"
29+
#include "source_lcao/module_dftu/dftu.h"
30+
#include "source_pw/module_pwdft/elecond.h"
31+
#include "source_pw/module_pwdft/forces.h"
32+
#include "source_pw/module_pwdft/hamilt_pw.h"
33+
#include "source_pw/module_pwdft/onsite_projector.h"
34+
#include "source_pw/module_pwdft/stress_pw.h"
3535

3636
#include <iostream>
3737

@@ -717,16 +717,25 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
717717
//------------------------------------------------------------------
718718
// 4) calculate band-decomposed (partial) charge density in pw basis
719719
//------------------------------------------------------------------
720-
const std::vector<int> out_pchg = PARAM.inp.out_pchg;
721-
if (out_pchg.size() > 0)
720+
if (PARAM.inp.out_pchg.size() > 0)
722721
{
723-
ModuleIO::get_pchg_pw(out_pchg,
722+
if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
723+
{
724+
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
725+
}
726+
727+
// Refresh __kspw_psi
728+
this->__kspw_psi = PARAM.inp.precision == "single"
729+
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
730+
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
731+
732+
ModuleIO::get_pchg_pw(PARAM.inp.out_pchg,
724733
this->kspw_psi->get_nbands(),
725734
PARAM.inp.nspin,
726735
this->pw_rhod->nxyz,
727736
this->chr.ngmc,
728737
&ucell,
729-
this->psi,
738+
this->__kspw_psi,
730739
this->pw_rhod,
731740
this->pw_wfc,
732741
this->ctx,
@@ -943,20 +952,25 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
943952
//----------------------------------------------------------
944953
//! 5) Print out electronic wave functions in real space
945954
//----------------------------------------------------------
946-
const std::vector<int> out_wfc_norm = PARAM.inp.out_wfc_norm;
947-
const std::vector<int> out_wfc_re_im = PARAM.inp.out_wfc_re_im;
948-
if (out_wfc_norm.size() > 0 || out_wfc_re_im.size() > 0)
955+
if (PARAM.inp.out_wfc_norm.size() > 0 || PARAM.inp.out_wfc_re_im.size() > 0)
949956
{
950-
ModuleIO::get_wf_pw(out_wfc_norm,
951-
out_wfc_re_im,
957+
if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
958+
{
959+
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
960+
}
961+
962+
// Refresh __kspw_psi
963+
this->__kspw_psi = PARAM.inp.precision == "single"
964+
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
965+
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
966+
967+
ModuleIO::get_wf_pw(PARAM.inp.out_wfc_norm,
968+
PARAM.inp.out_wfc_re_im,
952969
this->kspw_psi->get_nbands(),
953970
PARAM.inp.nspin,
954-
this->pw_rhod->nx,
955-
this->pw_rhod->ny,
956-
this->pw_rhod->nz,
957971
this->pw_rhod->nxyz,
958972
&ucell,
959-
this->psi,
973+
this->__kspw_psi,
960974
this->pw_wfc,
961975
this->ctx,
962976
this->Pgrid,
@@ -991,29 +1005,29 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
9911005

9921006
ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc;
9931007
write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx,
994-
PARAM.inp.nelec,
995-
PARAM.inp.of_tf_weight,
996-
PARAM.inp.of_vw_weight,
997-
PARAM.inp.of_ml_chi_p,
998-
PARAM.inp.of_ml_chi_q,
999-
PARAM.inp.of_ml_chi_xi,
1000-
PARAM.inp.of_ml_chi_pnl,
1001-
PARAM.inp.of_ml_chi_qnl,
1002-
PARAM.inp.of_ml_nkernel,
1003-
PARAM.inp.of_ml_kernel,
1004-
PARAM.inp.of_ml_kernel_scaling,
1005-
PARAM.inp.of_ml_yukawa_alpha,
1006-
PARAM.inp.of_ml_kernel_file,
1007-
ucell.omega,
1008-
this->pw_rho);
1008+
PARAM.inp.nelec,
1009+
PARAM.inp.of_tf_weight,
1010+
PARAM.inp.of_vw_weight,
1011+
PARAM.inp.of_ml_chi_p,
1012+
PARAM.inp.of_ml_chi_q,
1013+
PARAM.inp.of_ml_chi_xi,
1014+
PARAM.inp.of_ml_chi_pnl,
1015+
PARAM.inp.of_ml_chi_qnl,
1016+
PARAM.inp.of_ml_nkernel,
1017+
PARAM.inp.of_ml_kernel,
1018+
PARAM.inp.of_ml_kernel_scaling,
1019+
PARAM.inp.of_ml_yukawa_alpha,
1020+
PARAM.inp.of_ml_kernel_file,
1021+
ucell.omega,
1022+
this->pw_rho);
10091023

10101024
write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir,
1011-
this->kspw_psi,
1012-
this->pelec,
1013-
this->pw_wfc,
1014-
this->pw_rho,
1015-
ucell,
1016-
this->pelec->pot->get_effective_v(0));
1025+
this->kspw_psi,
1026+
this->pelec,
1027+
this->pw_wfc,
1028+
this->pw_rho,
1029+
ucell,
1030+
this->pelec->pot->get_effective_v(0));
10171031
}
10181032
#endif
10191033
}

source/source_io/get_pchg_pw.h

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
1212
const int nxyz,
1313
const int chr_ngmc,
1414
UnitCell* ucell,
15-
const psi::Psi<std::complex<double>>* psi,
15+
const psi::Psi<std::complex<double>, Device>* kspw_psi,
1616
const ModulePW::PW_Basis* pw_rhod,
1717
const ModulePW::PW_Basis_K* pw_wfc,
1818
const Device* ctx,
@@ -67,9 +67,17 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
6767
bands_picked[i] = static_cast<int>(out_pchg[i]);
6868
}
6969

70+
// Allocate host memory
7071
std::vector<std::complex<double>> wfcr(nxyz);
7172
std::vector<std::vector<double>> rho_band(nspin, std::vector<double>(nxyz));
7273

74+
// Allocate device memory
75+
std::complex<double>* wfcr_device = nullptr;
76+
if (!std::is_same<Device, base_device::DEVICE_CPU>::value)
77+
{
78+
base_device::memory::resize_memory_op<std::complex<double>, Device>()(wfcr_device, nxyz);
79+
}
80+
7381
for (int ib = 0; ib < nbands; ++ib)
7482
{
7583
// Skip the loop iteration if bands_picked[ib] is 0
@@ -91,8 +99,21 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
9199
const int spin_index = kv.isk[ik]; // spin index
92100
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
93101

94-
psi->fix_k(ik);
95-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
102+
kspw_psi->fix_k(ik);
103+
104+
// FFT on device and copy result back to host
105+
if (std::is_same<Device, base_device::DEVICE_CPU>::value)
106+
{
107+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr.data(), ik);
108+
}
109+
else
110+
{
111+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr_device, ik);
112+
113+
base_device::memory::synchronize_memory_op<std::complex<double>,
114+
base_device::DEVICE_CPU,
115+
Device>()(wfcr.data(), wfcr_device, nxyz);
116+
}
96117

97118
// To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
98119
double wg_sum_k = 0.0;
@@ -142,8 +163,21 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
142163
const int spin_index = kv.isk[ik]; // spin index
143164
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
144165

145-
psi->fix_k(ik);
146-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
166+
kspw_psi->fix_k(ik);
167+
168+
// FFT on device and copy result back to host
169+
if (std::is_same<Device, base_device::DEVICE_CPU>::value)
170+
{
171+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr.data(), ik);
172+
}
173+
else
174+
{
175+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr_device, ik);
176+
177+
base_device::memory::synchronize_memory_op<std::complex<double>,
178+
base_device::DEVICE_CPU,
179+
Device>()(wfcr.data(), wfcr_device, nxyz);
180+
}
147181

148182
double w1 = static_cast<double>(kv.wk[ik] / ucell->omega);
149183

source/source_io/get_wf_pw.h

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,16 @@
11
#ifndef GET_WF_PW_H
22
#define GET_WF_PW_H
33

4-
#include "cube_io.h"
5-
#include "source_base/module_device/device.h"
6-
#include "source_base/tool_quit.h"
7-
#include "source_basis/module_pw/pw_basis.h"
8-
#include "source_basis/module_pw/pw_basis_k.h"
9-
#include "source_cell/unitcell.h"
10-
#include "source_estate/elecstate.h"
11-
#include "source_estate/module_charge/symmetry_rho.h"
12-
#include "source_pw/module_pwdft/parallel_grid.h"
13-
#include "source_psi/psi.h"
14-
15-
#include <string>
16-
#include <vector>
17-
184
namespace ModuleIO
195
{
206
template <typename Device>
217
void get_wf_pw(const std::vector<int>& out_wfc_norm,
228
const std::vector<int>& out_wfc_re_im,
239
const int nbands,
2410
const int nspin,
25-
const int nx,
26-
const int ny,
27-
const int nz,
2811
const int nxyz,
2912
UnitCell* ucell,
30-
const psi::Psi<std::complex<double>>* psi,
13+
const psi::Psi<std::complex<double>, Device>* kspw_psi,
3114
const ModulePW::PW_Basis_K* pw_wfc,
3215
const Device* ctx,
3316
const Parallel_Grid& pgrid,
@@ -94,9 +77,17 @@ void get_wf_pw(const std::vector<int>& out_wfc_norm,
9477
bands_picked_re_im[i] = static_cast<int>(out_wfc_re_im[i]);
9578
}
9679

80+
// Allocate host memory
9781
std::vector<std::complex<double>> wfcr_norm(nxyz);
9882
std::vector<std::vector<double>> rho_band_norm(nspin, std::vector<double>(nxyz));
9983

84+
// Allocate device memory
85+
std::complex<double>* wfcr_norm_device = nullptr;
86+
if (!std::is_same<Device, base_device::DEVICE_CPU>::value)
87+
{
88+
base_device::memory::resize_memory_op<std::complex<double>, Device>()(wfcr_norm_device, nxyz);
89+
}
90+
10091
for (int ib = 0; ib < nbands; ++ib)
10192
{
10293
// Skip the loop iteration if bands_picked[ib] is 0
@@ -115,8 +106,22 @@ void get_wf_pw(const std::vector<int>& out_wfc_norm,
115106
const int spin_index = kv.isk[ik]; // spin index
116107
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
117108

118-
psi->fix_k(ik);
119-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr_norm.data(), ik);
109+
kspw_psi->fix_k(ik);
110+
111+
// FFT on device and copy result back to host
112+
if (std::is_same<Device, base_device::DEVICE_CPU>::value)
113+
{
114+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr_norm.data(), ik);
115+
}
116+
else
117+
{
118+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfcr_norm_device, ik);
119+
120+
base_device::memory::synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, Device>()(
121+
wfcr_norm.data(),
122+
wfcr_norm_device,
123+
nxyz);
124+
}
120125

121126
// To ensure the normalization of charge density in multi-k calculation
122127
double wg_sum_k = 0.0;
@@ -159,10 +164,18 @@ void get_wf_pw(const std::vector<int>& out_wfc_norm,
159164
}
160165
}
161166

167+
// Allocate host memory
162168
std::vector<std::complex<double>> wfc_re_im(nxyz);
163169
std::vector<std::vector<double>> rho_band_re(nspin, std::vector<double>(nxyz));
164170
std::vector<std::vector<double>> rho_band_im(nspin, std::vector<double>(nxyz));
165171

172+
// Allocate device memory
173+
std::complex<double>* wfc_re_im_device = nullptr;
174+
if (!std::is_same<Device, base_device::DEVICE_CPU>::value)
175+
{
176+
base_device::memory::resize_memory_op<std::complex<double>, Device>()(wfc_re_im_device, nxyz);
177+
}
178+
166179
for (int ib = 0; ib < nbands; ++ib)
167180
{
168181
// Skip the loop iteration if bands_picked[ib] is 0
@@ -182,8 +195,22 @@ void get_wf_pw(const std::vector<int>& out_wfc_norm,
182195
const int spin_index = kv.isk[ik]; // spin index
183196
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
184197

185-
psi->fix_k(ik);
186-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfc_re_im.data(), ik);
198+
kspw_psi->fix_k(ik);
199+
200+
// FFT on device and copy result back to host
201+
if (std::is_same<Device, base_device::DEVICE_CPU>::value)
202+
{
203+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfc_re_im.data(), ik);
204+
}
205+
else
206+
{
207+
pw_wfc->recip_to_real(ctx, &kspw_psi[0](ib, 0), wfc_re_im_device, ik);
208+
209+
base_device::memory::synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, Device>()(
210+
wfc_re_im.data(),
211+
wfc_re_im_device,
212+
nxyz);
213+
}
187214

188215
// To ensure the normalization of charge density in multi-k calculation
189216
double wg_sum_k = 0.0;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
INPUT_PARAMETERS
2+
#Parameters (1.General)
3+
suffix autotest
4+
calculation scf
5+
6+
nbands 6
7+
symmetry 0
8+
pseudo_dir ../../PP_ORB
9+
10+
device gpu
11+
12+
#Parameters (2.Iteration)
13+
ecutwfc 100
14+
scf_thr 1e-9
15+
scf_nmax 100
16+
17+
#Parameters (3.Basis)
18+
basis_type pw
19+
20+
out_wfc_norm 1*1
21+
22+
pw_seed 1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
GPU version / test the output of out_wfc_norm
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
etotref -211.878925376034
2+
etotperatomref -105.9394626880
3+
wfi1s1k1.cube 21.94721795
4+
wfi1s1k2.cube 19.65847203
5+
wfi1s1k3.cube 19.65846259
6+
wfi1s1k4.cube 19.16322464
7+
wfi1s1k5.cube 19.65846987
8+
wfi1s1k6.cube 19.17198153
9+
wfi1s1k7.cube 18.92365819
10+
wfi1s1k8.cube 19.6584683
11+
totaltimeref 3.26

0 commit comments

Comments
 (0)