Skip to content

Commit dc7147d

Browse files
authored
Feature: enable cal_force and cal_stress in nscf (#5752)
* Feature: enable cal_force and cal_stress in nscf * Fix: enable force and stress in uspp nscf * update unit tests
1 parent 25865a9 commit dc7147d

File tree

8 files changed

+123
-99
lines changed

8 files changed

+123
-99
lines changed

source/module_elecstate/elecstate_pw.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
238238
}
239239

240240
template <typename T, typename Device>
241-
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
241+
void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
242242
{
243243
const T one{1, 0};
244244
const T zero{0, 0};
@@ -392,6 +392,12 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
392392
}
393393
}
394394
delmem_complex_op()(this->ctx, becp);
395+
}
396+
397+
template <typename T, typename Device>
398+
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
399+
{
400+
this->cal_becsum(psi);
395401

396402
// transform soft charge to recip space using smooth grids
397403
T* rhog = nullptr;

source/module_elecstate/elecstate_pw.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class ElecStatePW : public ElecState
3535

3636
virtual void cal_tau(const psi::Psi<T, Device>& psi);
3737

38+
//! calculate becsum for uspp
39+
void cal_becsum(const psi::Psi<T, Device>& psi);
40+
3841
Real* becsum = nullptr;
3942

4043
//! init rho_data and kin_r_data
@@ -61,7 +64,7 @@ class ElecStatePW : public ElecState
6164

6265
//! calcualte rho for each k
6366
void rhoBandK(const psi::Psi<T, Device>& psi);
64-
67+
6568
//! add to the charge density in reciprocal space the part which is due to the US augmentation.
6669
void add_usrho(const psi::Psi<T, Device>& psi);
6770

source/module_hsolver/hsolver_lcaopw.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
281281
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->calEBand();
282282
if (skip_charge)
283283
{
284+
if (PARAM.globalv.use_uspp)
285+
{
286+
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->cal_becsum(psi);
287+
}
284288
ModuleBase::timer::tick("HSolverLIP", "solve");
285289
return;
286290
}

source/module_hsolver/hsolver_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
337337
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->calEBand();
338338
if (skip_charge)
339339
{
340+
if (PARAM.globalv.use_uspp)
341+
{
342+
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->cal_becsum(psi);
343+
}
340344
ModuleBase::timer::tick("HSolverPW", "solve");
341345
return;
342346
}

source/module_hsolver/test/hsolver_supplementary_mock.h

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "module_elecstate/elecstate.h"
2+
#include "module_elecstate/elecstate_pw.h"
33
#include "module_psi/wavefunc.h"
44

55
namespace elecstate
@@ -62,6 +62,46 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge
6262
return;
6363
}
6464

65+
template <typename T, typename Device>
66+
ElecStatePW<T, Device>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
67+
Charge* chg_in,
68+
K_Vectors* pkv_in,
69+
UnitCell* ucell_in,
70+
pseudopot_cell_vnl* ppcell_in,
71+
ModulePW::PW_Basis* rhodpw_in,
72+
ModulePW::PW_Basis* rhopw_in,
73+
ModulePW::PW_Basis_Big* bigpw_in)
74+
: basis(wfc_basis_in)
75+
{
76+
}
77+
78+
template <typename T, typename Device>
79+
ElecStatePW<T, Device>::~ElecStatePW()
80+
{
81+
}
82+
83+
template <typename T, typename Device>
84+
void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
85+
{
86+
}
87+
88+
template <typename T, typename Device>
89+
void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
90+
{
91+
}
92+
93+
template <typename T, typename Device>
94+
void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
95+
{
96+
}
97+
98+
template class ElecStatePW<std::complex<float>, base_device::DEVICE_CPU>;
99+
template class ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>;
100+
#if ((defined __CUDA) || (defined __ROCM))
101+
template class ElecStatePW<std::complex<float>, base_device::DEVICE_GPU>;
102+
template class ElecStatePW<std::complex<double>, base_device::DEVICE_GPU>;
103+
#endif
104+
65105
Potential::~Potential()
66106
{
67107
}

source/module_hsolver/test/test_hsolver_sdft.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,40 +23,11 @@ Sto_Func<REAL>::Sto_Func()
2323
}
2424
template class Sto_Func<double>;
2525

26-
27-
template <>
28-
elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
29-
Charge* chg_in,
30-
K_Vectors* pkv_in,
31-
UnitCell* ucell_in,
32-
pseudopot_cell_vnl* ppcell_in,
33-
ModulePW::PW_Basis* rhodpw_in,
34-
ModulePW::PW_Basis* rhopw_in,
35-
ModulePW::PW_Basis_Big* bigpw_in)
36-
: basis(wfc_basis_in)
37-
{
38-
}
39-
40-
template<>
41-
elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::~ElecStatePW()
42-
{
43-
}
44-
4526
template<>
4627
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::init_rho_data()
4728
{
4829
}
4930

50-
template<>
51-
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::psiToRho(const psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& psi)
52-
{
53-
}
54-
55-
template<>
56-
void elecstate::ElecStatePW<std::complex<double>, base_device::DEVICE_CPU>::cal_tau(const psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& psi)
57-
{
58-
}
59-
6031
template <typename REAL, typename Device>
6132
StoChe<REAL, Device>::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto)
6233
{

source/module_io/read_input_item_system.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ void ReadInput::item_system()
207207
item.annotation = "if calculate the force at the end of the electronic iteration";
208208
item.reset_value = [](const Input_Item& item, Parameter& para) {
209209
std::vector<std::string> use_force = {"cell-relax", "relax", "md"};
210-
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "nscf", "get_S"};
210+
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "get_S"};
211211
if (std::find(use_force.begin(), use_force.end(), para.input.calculation) != use_force.end())
212212
{
213213
if (!para.input.cal_force)

source/module_relax/relax_driver.cpp

Lines changed: 62 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -54,90 +54,86 @@ void Relax_Driver::relax_driver(ModuleESolver::ESolver* p_esolver, UnitCell& uce
5454
time_t fstart = time(nullptr);
5555
ModuleBase::matrix force;
5656
ModuleBase::matrix stress;
57-
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
58-
{
59-
// I'm considering putting force and stress
60-
// as part of ucell and use ucell to pass information
61-
// back and forth between esolver and relaxation
62-
// but I'll use force and stress explicitly here for now
6357

64-
// calculate the total energy
65-
this->etot = p_esolver->cal_energy();
58+
// I'm considering putting force and stress
59+
// as part of ucell and use ucell to pass information
60+
// back and forth between esolver and relaxation
61+
// but I'll use force and stress explicitly here for now
62+
63+
// calculate the total energy
64+
this->etot = p_esolver->cal_energy();
65+
66+
// calculate and gather all parts of total ionic forces
67+
if (PARAM.inp.cal_force)
68+
{
69+
p_esolver->cal_force(ucell, force);
70+
}
71+
// calculate and gather all parts of stress
72+
if (PARAM.inp.cal_stress)
73+
{
74+
p_esolver->cal_stress(ucell, stress);
75+
}
6676

67-
// calculate and gather all parts of total ionic forces
68-
if (PARAM.inp.cal_force)
77+
if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
78+
{
79+
if (PARAM.inp.relax_new)
6980
{
70-
p_esolver->cal_force(ucell, force);
81+
stop = rl.relax_step(ucell, force, stress, this->etot);
7182
}
72-
// calculate and gather all parts of stress
73-
if (PARAM.inp.cal_stress)
83+
else
7484
{
75-
p_esolver->cal_stress(ucell, stress);
85+
stop = rl_old.relax_step(istep,
86+
this->etot,
87+
ucell,
88+
force,
89+
stress,
90+
force_step,
91+
stress_step); // pengfei Li 2018-05-14
7692
}
77-
78-
if (PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax")
93+
// print structure
94+
// changelog 20240509
95+
// because I move out the dependence on GlobalV from UnitCell::print_stru_file
96+
// so its parameter is calculated here
97+
bool need_orb = PARAM.inp.basis_type == "pw";
98+
need_orb = need_orb && PARAM.inp.psi_initializer;
99+
need_orb = need_orb && PARAM.inp.init_wfc.substr(0, 3) == "nao";
100+
need_orb = need_orb || PARAM.inp.basis_type == "lcao";
101+
need_orb = need_orb || PARAM.inp.basis_type == "lcao_in_pw";
102+
std::stringstream ss, ss1;
103+
ss << PARAM.globalv.global_out_dir << "STRU_ION_D";
104+
ucell.print_stru_file(ss.str(),
105+
PARAM.inp.nspin,
106+
true,
107+
PARAM.inp.calculation == "md",
108+
PARAM.inp.out_mul,
109+
need_orb,
110+
PARAM.globalv.deepks_setorb,
111+
GlobalV::MY_RANK);
112+
113+
if (Ions_Move_Basic::out_stru)
79114
{
80-
if (PARAM.inp.relax_new)
81-
{
82-
stop = rl.relax_step(ucell,force, stress, this->etot);
83-
}
84-
else
85-
{
86-
stop = rl_old.relax_step(istep,
87-
this->etot,
88-
ucell,
89-
force,
90-
stress,
91-
force_step,
92-
stress_step); // pengfei Li 2018-05-14
93-
}
94-
// print structure
95-
// changelog 20240509
96-
// because I move out the dependence on GlobalV from UnitCell::print_stru_file
97-
// so its parameter is calculated here
98-
bool need_orb = PARAM.inp.basis_type == "pw";
99-
need_orb = need_orb && PARAM.inp.psi_initializer;
100-
need_orb = need_orb && PARAM.inp.init_wfc.substr(0, 3) == "nao";
101-
need_orb = need_orb || PARAM.inp.basis_type == "lcao";
102-
need_orb = need_orb || PARAM.inp.basis_type == "lcao_in_pw";
103-
std::stringstream ss, ss1;
104-
ss << PARAM.globalv.global_out_dir << "STRU_ION_D";
105-
ucell.print_stru_file(ss.str(),
115+
ss1 << PARAM.globalv.global_out_dir << "STRU_ION";
116+
ss1 << istep << "_D";
117+
ucell.print_stru_file(ss1.str(),
106118
PARAM.inp.nspin,
107119
true,
108120
PARAM.inp.calculation == "md",
109121
PARAM.inp.out_mul,
110122
need_orb,
111123
PARAM.globalv.deepks_setorb,
112124
GlobalV::MY_RANK);
113-
114-
if (Ions_Move_Basic::out_stru)
115-
{
116-
ss1 << PARAM.globalv.global_out_dir << "STRU_ION";
117-
ss1 << istep << "_D";
118-
ucell.print_stru_file(ss1.str(),
119-
PARAM.inp.nspin,
120-
true,
121-
PARAM.inp.calculation == "md",
122-
PARAM.inp.out_mul,
123-
need_orb,
124-
PARAM.globalv.deepks_setorb,
125-
GlobalV::MY_RANK);
126-
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU_NOW.cif",
127-
ucell,
128-
"# Generated by ABACUS ModuleIO::CifParser",
129-
"data_?");
130-
}
131-
132-
ModuleIO::output_after_relax(stop, p_esolver->conv_esolver, GlobalV::ofs_running);
125+
ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU_NOW.cif",
126+
ucell,
127+
"# Generated by ABACUS ModuleIO::CifParser",
128+
"data_?");
133129
}
134130

135-
#ifdef __RAPIDJSON
136-
// add the energy to outout
137-
Json::add_output_energy(p_esolver->cal_energy() * ModuleBase::Ry_to_eV);
138-
#endif
131+
ModuleIO::output_after_relax(stop, p_esolver->conv_esolver, GlobalV::ofs_running);
139132
}
133+
140134
#ifdef __RAPIDJSON
135+
// add the energy to outout
136+
Json::add_output_energy(p_esolver->cal_energy() * ModuleBase::Ry_to_eV);
141137
// add Json of cell coo stress force
142138
double unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI, 3) * 1.0e-8;
143139
double fac = ModuleBase::Ry_to_eV / 0.529177;

0 commit comments

Comments
 (0)