Skip to content

Commit 4998c0c

Browse files
committed
move setup_pw.cpp to setup_pwwfc.cpp in module_pwdft
1 parent 11d9427 commit 4998c0c

File tree

8 files changed

+121
-116
lines changed

8 files changed

+121
-116
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ OBJS_ESOLVER=esolver.o\
274274
esolver_of_tool.o\
275275
esolver_of_interface.o\
276276
pw_others.o\
277-
pw_setup.o\
278277

279278
OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
280279
esolver_ks_lcao_tddft.o\
@@ -725,6 +724,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
725724
fp_energy.o\
726725
setup_pot.o\
727726
setup_pwrho.o\
727+
setup_pwwfc.o\
728728
forces.o\
729729
forces_us.o\
730730
forces_nl.o\

source/source_esolver/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ list(APPEND objects
1313
esolver_of_interface.cpp
1414
esolver_of_tool.cpp
1515
pw_others.cpp
16-
pw_setup.cpp
1716
)
1817
if(ENABLE_LCAO)
1918
list(APPEND objects

source/source_esolver/esolver_ks.cpp

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "esolver_ks.h"
2-
#include "pw_setup.h" // setup plane wave
32

43
#include "source_base/timer.h"
54
#include "source_base/global_variable.h"
@@ -25,6 +24,7 @@
2524

2625
#include "source_estate/update_pot.h" // mohan add 20251016
2726
#include "source_estate/module_charge/chgmixing.h" // mohan add 20251018
27+
#include "source_pw/module_pwdft/setup_pwwfc.h" // mohan add 20251018
2828

2929
namespace ModuleESolver
3030
{
@@ -40,10 +40,12 @@ ESolver_KS<T, Device>::~ESolver_KS()
4040
// do not add any codes in this deconstructor funcion
4141
//****************************************************
4242
delete this->psi;
43-
delete this->pw_wfc;
4443
delete this->p_hamilt;
4544
delete this->p_chgmix;
4645
this->ppcell.release_memory();
46+
47+
// mohan add 2025-10-18, should be put int clean() function
48+
pw::teardown_pwwfc(this->pw_wfc);
4749
}
4850

4951

@@ -65,69 +67,47 @@ void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para
6567
this->niter = maxniter;
6668
this->drho = 0.0;
6769

68-
std::string fft_device = inp.device;
69-
70-
//! 3) setup pw_wfc
71-
// currently LCAO doesn't support GPU acceleration of FFT
72-
if(inp.basis_type == "lcao")
73-
{
74-
fft_device = "cpu";
75-
}
76-
std::string fft_precision = inp.precision;
77-
#ifdef __ENABLE_FLOAT_FFTW
78-
if (inp.cal_cond && inp.esolver_type == "sdft")
79-
{
80-
fft_precision = "mixing";
81-
}
82-
#endif
83-
84-
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, fft_precision);
85-
86-
// for LCAO calculations, we need to set bx, by, and bz
87-
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
88-
tmp->setbxyz(inp.bx, inp.by, inp.bz);
70+
// cell_factor
71+
this->ppcell.cell_factor = inp.cell_factor;
8972

90-
//! 4) setup charge mixing
73+
//! 3) setup charge mixing
9174
p_chgmix = new Charge_Mixing();
9275
p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod);
9376

94-
// cell_factor
95-
this->ppcell.cell_factor = inp.cell_factor;
96-
9777
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL");
9878

99-
//! 5) setup Exc for the first element '0', because all elements have same exc
79+
//! 4) setup Exc for the first element '0', because all elements have same exc
10080
XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func);
10181

102-
//! 6) setup the charge mixing parameters
82+
//! 5) setup the charge mixing parameters
10383
p_chgmix->set_mixing(inp.mixing_mode, inp.mixing_beta, inp.mixing_ndim,
10484
inp.mixing_gg0, inp.mixing_tau, inp.mixing_beta_mag, inp.mixing_gg0_mag,
10585
inp.mixing_gg0_min, inp.mixing_angle, inp.mixing_dmr, ucell.omega, ucell.tpiba);
10686

10787
p_chgmix->init_mixing();
10888

109-
//! 7) symmetry analysis should be performed every time the cell is changed
89+
//! 6) symmetry analysis should be performed every time the cell is changed
11090
if (ModuleSymmetry::Symmetry::symm_flag == 1)
11191
{
11292
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
11393
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
11494
}
11595

116-
//! 8) Setup the k points according to symmetry.
96+
//! 7) Setup the k points according to symmetry.
11797
this->kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
11898
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
11999

120-
//! 9) print information
100+
//! 8) print information
121101
ModuleIO::print_parameters(ucell, this->kv, inp);
122102

123-
//! 10) setup plane wave for electronic wave functions
124-
ModuleESolver::pw_setup(inp, ucell, *this->pw_rho, this->kv, *this->pw_wfc);
103+
//! 9) setup plane wave for electronic wave functions
104+
pw::setup_pwwfc(inp, ucell, *this->pw_rho, this->kv, this->pw_wfc);
125105

126-
//! 11) parallel of FFT grid
106+
//! 10) parallel of FFT grid
127107
Pgrid.init(this->pw_rhod->nx, this->pw_rhod->ny, this->pw_rhod->nz,
128108
this->pw_rhod->nplane, this->pw_rhod->nrxx, pw_big->nbz, pw_big->bz);
129109

130-
//! 12) calculate the structure factor
110+
//! 11) calculate the structure factor
131111
this->sf.setup_structure_factor(&ucell, Pgrid, this->pw_rhod);
132112
}
133113

source/source_esolver/pw_setup.cpp

Lines changed: 0 additions & 52 deletions
This file was deleted.

source/source_esolver/pw_setup.h

Lines changed: 0 additions & 26 deletions
This file was deleted.

source/source_pw/module_pwdft/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ list(APPEND objects
1414
operator_pw/exx_pw_pot.cpp
1515
setup_pot.cpp
1616
setup_pwrho.cpp
17+
setup_pwwfc.cpp
1718
forces_nl.cpp
1819
forces_cc.cpp
1920
forces_scc.cpp
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#include "source_pw/module_pwdft/setup_pwwfc.h" // pw_wfc
2+
#include "source_base/parallel_comm.h" // POOL_WORLD
3+
#include "source_io/print_info.h" // print information
4+
5+
void pw::teardown_pwwfc(ModulePW::PW_Basis_K* &pw_wfc)
6+
{
7+
delete pw_wfc;
8+
}
9+
10+
void pw::setup_pwwfc(const Input_para& inp,
11+
const UnitCell& ucell,
12+
const ModulePW::PW_Basis& pw_rho,
13+
K_Vectors& kv,
14+
ModulePW::PW_Basis_K* &pw_wfc)
15+
{
16+
ModuleBase::TITLE("pw", "pw_setup");
17+
18+
std::string fft_device = inp.device;
19+
20+
//! setup pw_wfc
21+
// currently LCAO doesn't support GPU acceleration of FFT
22+
if(inp.basis_type == "lcao")
23+
{
24+
fft_device = "cpu";
25+
}
26+
std::string fft_precision = inp.precision;
27+
#ifdef __ENABLE_FLOAT_FFTW
28+
if (inp.cal_cond && inp.esolver_type == "sdft")
29+
{
30+
fft_precision = "mixing";
31+
}
32+
#endif
33+
34+
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, fft_precision);
35+
36+
37+
// for LCAO calculations, we need to set bx, by, and bz
38+
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
39+
tmp->setbxyz(inp.bx, inp.by, inp.bz);
40+
41+
42+
43+
//! new plane wave basis, fft grids, etc.
44+
#ifdef __MPI
45+
pw_wfc->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
46+
#endif
47+
48+
pw_wfc->initgrids(inp.ref_cell_factor * ucell.lat0,
49+
ucell.latvec,
50+
pw_rho.nx,
51+
pw_rho.ny,
52+
pw_rho.nz);
53+
54+
pw_wfc->initparameters(false, inp.ecutwfc, kv.get_nks(), kv.kvec_d.data());
55+
56+
#ifdef __MPI
57+
if (inp.pw_seed > 0)
58+
{
59+
MPI_Allreduce(MPI_IN_PLACE, &pw_wfc->ggecut, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
60+
}
61+
// qianrui add 2021-8-13 to make different kpar parameters can get the same
62+
// results
63+
#endif
64+
65+
pw_wfc->fft_bundle.initfftmode(inp.fft_mode);
66+
pw_wfc->setuptransform();
67+
68+
//! initialize the number of plane waves for each k point
69+
for (int ik = 0; ik < kv.get_nks(); ++ik)
70+
{
71+
kv.ngk[ik] = pw_wfc->npwk[ik];
72+
}
73+
74+
pw_wfc->collect_local_pw(inp.erf_ecut, inp.erf_height, inp.erf_sigma);
75+
76+
ModuleIO::print_wfcfft(inp, *pw_wfc, GlobalV::ofs_running);
77+
78+
return;
79+
}
80+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef SETUP_PWWFC_H
2+
#define SETUP_PWWFC_H
3+
4+
#include "source_io/module_parameter/parameter.h" // input parameters
5+
#include "source_cell/unitcell.h" // cell information
6+
#include "source_cell/klist.h" // k-points
7+
#include "source_basis/module_pw/pw_basis.h" // pw_rho
8+
#include "source_basis/module_pw/pw_basis_k.h" // pw_wfc
9+
10+
namespace pw
11+
{
12+
13+
void teardown_pwwfc(ModulePW::PW_Basis_K* &pw_wfc);
14+
15+
void setup_pwwfc(const Input_para& inp,
16+
const UnitCell& ucell,
17+
const ModulePW::PW_Basis& pw_rho,
18+
K_Vectors& kv,
19+
ModulePW::PW_Basis_K* &pw_wfc);
20+
21+
}
22+
23+
#endif

0 commit comments

Comments
 (0)