Skip to content

Commit 4ac1e8a

Browse files
Refactor: remove GlobalC::ucell in esolver (#5569)
* Refactor: remove GlobalC::ucell in esolver * [pre-commit.ci lite] apply automatic fixes * update next_direct * Refactor: put ucell as the first parameter * rename cell to ucell * update unitests * update opt_TN.hpp --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent c9f7973 commit 4ac1e8a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+617
-584
lines changed

source/driver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void Driver::atomic_world()
183183
//--------------------------------------------------
184184

185185
// where the actual stuff is done
186-
this->driver_run();
186+
this->driver_run(GlobalC::ucell);
187187

188188
ModuleBase::timer::finish(GlobalV::ofs_running);
189189
ModuleBase::Memory::print_all(GlobalV::ofs_running);

source/driver.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef DRIVER_H
22
#define DRIVER_H
33

4+
#include "module_cell/unitcell.h"
5+
46
class Driver
57
{
68
public:
@@ -34,7 +36,7 @@ class Driver
3436
void atomic_world();
3537

3638
// the actual calculations
37-
void driver_run();
39+
void driver_run(UnitCell& ucell);
3840
};
3941

4042
#endif

source/driver_run.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
* the configuration-changing subroutine takes force and stress and updates the
2525
* configuration
2626
*/
27-
void Driver::driver_run() {
27+
void Driver::driver_run(UnitCell& ucell)
28+
{
2829
ModuleBase::TITLE("Driver", "driver_line");
2930
ModuleBase::timer::tick("Driver", "driver_line");
3031

@@ -39,37 +40,35 @@ void Driver::driver_run() {
3940
#endif
4041

4142
// the life of ucell should begin here, mohan 2024-05-12
42-
// delete ucell as a GlobalC in near future
43-
GlobalC::ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
44-
Check_Atomic_Stru::check_atomic_stru(GlobalC::ucell,
45-
PARAM.inp.min_dist_coef);
43+
ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
44+
Check_Atomic_Stru::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);
4645

4746
//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
48-
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, GlobalC::ucell);
47+
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell);
4948

5049
//! 3: initialize Esolver and fill json-structure
51-
p_esolver->before_all_runners(PARAM.inp, GlobalC::ucell);
50+
p_esolver->before_all_runners(ucell, PARAM.inp);
5251

5352
// this Json part should be moved to before_all_runners, mohan 2024-05-12
5453
#ifdef __RAPIDJSON
55-
Json::gen_stru_wrapper(&GlobalC::ucell);
54+
Json::gen_stru_wrapper(&ucell);
5655
#endif
5756

5857
const std::string cal_type = PARAM.inp.calculation;
5958

6059
//! 4: different types of calculations
6160
if (cal_type == "md")
6261
{
63-
Run_MD::md_line(GlobalC::ucell, p_esolver, PARAM);
62+
Run_MD::md_line(ucell, p_esolver, PARAM);
6463
}
6564
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax" || cal_type == "nscf")
6665
{
6766
Relax_Driver rl_driver;
68-
rl_driver.relax_driver(p_esolver);
67+
rl_driver.relax_driver(p_esolver, ucell);
6968
}
7069
else if (cal_type == "get_S")
7170
{
72-
p_esolver->runner(0, GlobalC::ucell);
71+
p_esolver->runner(ucell, 0);
7372
}
7473
else
7574
{
@@ -79,11 +78,11 @@ void Driver::driver_run() {
7978
//! test_neighbour(LCAO),
8079
//! gen_bessel(PW), et al.
8180
const int istep = 0;
82-
p_esolver->others(istep);
81+
p_esolver->others(ucell, istep);
8382
}
8483

8584
//! 5: clean up esolver
86-
p_esolver->after_all_runners();
85+
p_esolver->after_all_runners(ucell);
8786

8887
ModuleESolver::clean_esolver(p_esolver);
8988

source/module_base/opt_TN.hpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef OPT_TN_H
22
#define OPT_TN_H
33

4-
#include <limits>
4+
#include "opt_CG.h"
55

6-
#include "./opt_CG.h"
6+
#include <limits>
77

88
namespace ModuleBase
99
{
@@ -25,7 +25,7 @@ class Opt_TN
2525
{
2626
this->mach_prec_ = std::numeric_limits<double>::epsilon(); // get machine precise
2727
}
28-
~Opt_TN(){};
28+
~Opt_TN() {};
2929

3030
/**
3131
* @brief Allocate the space for the arrays in cg_.
@@ -54,7 +54,9 @@ class Opt_TN
5454
{
5555
this->iter_ = 0;
5656
if (nx_new != 0)
57+
{
5758
this->nx_ = nx_new;
59+
}
5860
this->cg_.refresh(nx_new);
5961
}
6062

@@ -167,17 +169,23 @@ void Opt_TN::next_direct(double* px,
167169
epsilon = this->get_epsilon(px, cg_direct);
168170
// epsilon = 1e-9;
169171
for (int i = 0; i < this->nx_; ++i)
172+
{
170173
temp_x[i] = px[i] + epsilon * cg_direct[i];
174+
}
171175
(t->*p_calGradient)(temp_x, temp_gradient);
172176
for (int i = 0; i < this->nx_; ++i)
177+
{
173178
temp_Hcgd[i] = (temp_gradient[i] - pgradient[i]) / epsilon;
179+
}
174180

175181
// get CG step length and update rdirect
176182
cg_alpha = cg_.step_length(temp_Hcgd, cg_direct, cg_ifPD);
177183
if (cg_ifPD == -1) // Hessian is not positive definite, and cgiter = 1.
178184
{
179185
for (int i = 0; i < this->nx_; ++i)
186+
{
180187
rdirect[i] += cg_alpha * cg_direct[i];
188+
}
181189
flag = -1;
182190
break;
183191
}
@@ -188,14 +196,18 @@ void Opt_TN::next_direct(double* px,
188196
}
189197

190198
for (int i = 0; i < this->nx_; ++i)
199+
{
191200
rdirect[i] += cg_alpha * cg_direct[i];
201+
}
192202

193203
// store residuals used in truncated conditions
194204
last_residual = curr_residual;
195205
curr_residual = cg_.get_residual();
196206
cg_iter = cg_.get_iter();
197207
if (cg_iter == 1)
208+
{
198209
init_residual = curr_residual;
210+
}
199211

200212
// check truncated conditions
201213
// if (curr_residual < 1e-12)

source/module_cell/unitcell.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,9 @@ void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) {
551551
this->atoms = new Atom[this->ntype]; // atom species.
552552
this->set_atom_flag = true;
553553

554+
this->symm.epsilon = PARAM.inp.symmetry_prec;
555+
this->symm.epsilon_input = PARAM.inp.symmetry_prec;
556+
554557
bool ok = true;
555558
bool ok2 = true;
556559

source/module_esolver/esolver.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
247247
{
248248
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
249249
}
250-
p_esolver->before_all_runners(inp, ucell);
251-
p_esolver->runner(0, ucell); // scf-only
250+
p_esolver->before_all_runners(ucell, inp);
251+
p_esolver->runner(ucell, 0); // scf-only
252252
// force and stress is not needed currently,
253253
// they will be supported after the analytical gradient
254254
// of LR-TDDFT is implemented.

source/module_esolver/esolver.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,26 @@ class ESolver
2020
}
2121

2222
//! initialize the energy solver by using input parameters and cell modules
23-
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) = 0;
23+
virtual void before_all_runners(UnitCell& ucell, const Input_para& inp) = 0;
2424

2525
//! run energy solver
26-
virtual void runner(const int istep, UnitCell& cell) = 0;
26+
virtual void runner(UnitCell& cell, const int istep) = 0;
2727

2828
//! perform post processing calculations
29-
virtual void after_all_runners(){};
29+
virtual void after_all_runners(UnitCell& ucell){};
3030

3131
//! deal with exx and other calculation than scf/md/relax/cell-relax:
3232
//! such as nscf, get_wf and get_pchg
33-
virtual void others(const int istep){};
33+
virtual void others(UnitCell& ucell, const int istep) {};
3434

3535
//! calculate total energy of a given system
3636
virtual double cal_energy() = 0;
3737

3838
//! calcualte forces for the atoms in the given cell
39-
virtual void cal_force(ModuleBase::matrix& force) = 0;
39+
virtual void cal_force(UnitCell& ucell, ModuleBase::matrix& force) = 0;
4040

4141
//! calcualte stress of given cell
42-
virtual void cal_stress(ModuleBase::matrix& stress) = 0;
42+
virtual void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) = 0;
4343

4444
bool conv_esolver = true; // whether esolver is converged
4545

source/module_esolver/esolver_dp.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
namespace ModuleESolver
3232
{
3333

34-
void ESolver_DP::before_all_runners(const Input_para& inp, UnitCell& ucell)
34+
void ESolver_DP::before_all_runners(UnitCell& ucell, const Input_para& inp)
3535
{
3636
ucell_ = &ucell;
3737
dp_potential = 0;
@@ -57,7 +57,7 @@ void ESolver_DP::before_all_runners(const Input_para& inp, UnitCell& ucell)
5757
#endif
5858
}
5959

60-
void ESolver_DP::runner(const int istep, UnitCell& ucell)
60+
void ESolver_DP::runner(UnitCell& ucell, const int istep)
6161
{
6262
ModuleBase::TITLE("ESolver_DP", "runner");
6363
ModuleBase::timer::tick("ESolver_DP", "runner");
@@ -127,13 +127,13 @@ double ESolver_DP::cal_energy()
127127
return dp_potential;
128128
}
129129

130-
void ESolver_DP::cal_force(ModuleBase::matrix& force)
130+
void ESolver_DP::cal_force(UnitCell& ucell, ModuleBase::matrix& force)
131131
{
132132
force = dp_force;
133133
ModuleIO::print_force(GlobalV::ofs_running, *ucell_, "TOTAL-FORCE (eV/Angstrom)", force, false);
134134
}
135135

136-
void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
136+
void ESolver_DP::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
137137
{
138138
stress = dp_virial;
139139

@@ -148,7 +148,7 @@ void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
148148
ModuleIO::print_stress("TOTAL-STRESS", stress, true, false);
149149
}
150150

151-
void ESolver_DP::after_all_runners()
151+
void ESolver_DP::after_all_runners(UnitCell& ucell)
152152
{
153153
GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl;
154154
GlobalV::ofs_running << std::setprecision(16);

source/module_esolver/esolver_dp.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ class ESolver_DP : public ESolver
3636
* @param inp input parameters
3737
* @param cell unitcell information
3838
*/
39-
void before_all_runners(const Input_para& inp, UnitCell& cell) override;
39+
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
4040

4141
/**
4242
* @brief Run the DP solver for a given ion/md step and unit cell
4343
*
4444
* @param istep the current ion/md step
4545
* @param cell unitcell information
4646
*/
47-
void runner(const int istep, UnitCell& cell) override;
47+
void runner(UnitCell& cell, const int istep) override;
4848

4949
/**
5050
* @brief get the total energy without ion kinetic energy
@@ -59,21 +59,21 @@ class ESolver_DP : public ESolver
5959
*
6060
* @param force the computed atomic forces
6161
*/
62-
void cal_force(ModuleBase::matrix& force) override;
62+
void cal_force(UnitCell& ucell, ModuleBase::matrix& force) override;
6363

6464
/**
6565
* @brief get the computed lattice virials
6666
*
6767
* @param stress the computed lattice virials
6868
*/
69-
void cal_stress(ModuleBase::matrix& stress) override;
69+
void cal_stress(UnitCell& ucell, ModuleBase::matrix& stress) override;
7070

7171
/**
7272
* @brief Prints the final total energy of the DP model to the output file
7373
*
7474
* This function prints the final total energy of the DP model in eV to the output file along with some formatting.
7575
*/
76-
void after_all_runners() override;
76+
void after_all_runners(UnitCell& ucell) override;
7777

7878
private:
7979
/**

0 commit comments

Comments
 (0)