Skip to content

Commit e74aa2b

Browse files
committed
Refactor: remove GlobalC::ucell in esolver
1 parent e0202e0 commit e74aa2b

38 files changed

+481
-475
lines changed

source/driver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void Driver::atomic_world()
183183
//--------------------------------------------------
184184

185185
// where the actual stuff is done
186-
this->driver_run();
186+
this->driver_run(GlobalC::ucell);
187187

188188
ModuleBase::timer::finish(GlobalV::ofs_running);
189189
ModuleBase::Memory::print_all(GlobalV::ofs_running);

source/driver.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef DRIVER_H
22
#define DRIVER_H
33

4+
#include "module_cell/unitcell.h"
5+
46
class Driver
57
{
68
public:
@@ -34,7 +36,7 @@ class Driver
3436
void atomic_world();
3537

3638
// the actual calculations
37-
void driver_run();
39+
void driver_run(UnitCell& ucell);
3840
};
3941

4042
#endif

source/driver_run.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
* the configuration-changing subroutine takes force and stress and updates the
2525
* configuration
2626
*/
27-
void Driver::driver_run() {
27+
void Driver::driver_run(UnitCell& ucell)
28+
{
2829
ModuleBase::TITLE("Driver", "driver_line");
2930
ModuleBase::timer::tick("Driver", "driver_line");
3031

@@ -39,37 +40,35 @@ void Driver::driver_run() {
3940
#endif
4041

4142
// the life of ucell should begin here, mohan 2024-05-12
42-
// delete ucell as a GlobalC in near future
43-
GlobalC::ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
44-
Check_Atomic_Stru::check_atomic_stru(GlobalC::ucell,
45-
PARAM.inp.min_dist_coef);
43+
ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
44+
Check_Atomic_Stru::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);
4645

4746
//! 2: initialize the ESolver (depends on a set-up ucell after `setup_cell`)
48-
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, GlobalC::ucell);
47+
ModuleESolver::ESolver* p_esolver = ModuleESolver::init_esolver(PARAM.inp, ucell);
4948

5049
//! 3: initialize Esolver and fill json-structure
51-
p_esolver->before_all_runners(PARAM.inp, GlobalC::ucell);
50+
p_esolver->before_all_runners(PARAM.inp, ucell);
5251

5352
// this Json part should be moved to before_all_runners, mohan 2024-05-12
5453
#ifdef __RAPIDJSON
55-
Json::gen_stru_wrapper(&GlobalC::ucell);
54+
Json::gen_stru_wrapper(&ucell);
5655
#endif
5756

5857
const std::string cal_type = PARAM.inp.calculation;
5958

6059
//! 4: different types of calculations
6160
if (cal_type == "md")
6261
{
63-
Run_MD::md_line(GlobalC::ucell, p_esolver, PARAM);
62+
Run_MD::md_line(ucell, p_esolver, PARAM);
6463
}
6564
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax" || cal_type == "nscf")
6665
{
6766
Relax_Driver rl_driver;
68-
rl_driver.relax_driver(p_esolver);
67+
rl_driver.relax_driver(p_esolver, ucell);
6968
}
7069
else if (cal_type == "get_S")
7170
{
72-
p_esolver->runner(0, GlobalC::ucell);
71+
p_esolver->runner(0, ucell);
7372
}
7473
else
7574
{
@@ -79,11 +78,11 @@ void Driver::driver_run() {
7978
//! test_neighbour(LCAO),
8079
//! gen_bessel(PW), et al.
8180
const int istep = 0;
82-
p_esolver->others(istep);
81+
p_esolver->others(istep, ucell);
8382
}
8483

8584
//! 5: clean up esolver
86-
p_esolver->after_all_runners();
85+
p_esolver->after_all_runners(ucell);
8786

8887
ModuleESolver::clean_esolver(p_esolver);
8988

source/module_base/opt_TN.hpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#ifndef OPT_TN_H
22
#define OPT_TN_H
33

4-
#include <limits>
5-
64
#include "./opt_CG.h"
5+
#include "module_cell/unitcell.h"
6+
7+
#include <limits>
78

89
namespace ModuleBase
910
{
@@ -64,10 +65,11 @@ class Opt_TN
6465
double* pgradient, // df(x)/dx
6566
int& flag, // record which truncated condition was triggered, 0 for cond.1, 1 for cond.2, and 2 for cond.3
6667
double* rdirect, // next optimization direction
67-
T* t, // point of class T, which contains the gradient function
68-
void (T::*p_calGradient)(
69-
double* ptemp_x,
70-
double* rtemp_gradient) // a function point, which calculates the gradient at provided x
68+
UnitCell& ucell,
69+
T* t, // point of class T, which contains the gradient function
70+
void (T::*p_calGradient)(double* ptemp_x,
71+
double* rtemp_gradient,
72+
UnitCell& ucell) // a function point, which calculates the gradient at provided x
7173
);
7274

7375
int get_iter()
@@ -128,8 +130,9 @@ void Opt_TN::next_direct(double* px,
128130
double* pgradient,
129131
int& flag,
130132
double* rdirect,
133+
UnitCell& ucell,
131134
T* t,
132-
void (T::*p_calGradient)(double* px, double* rgradient))
135+
void (T::*p_calGradient)(double* px, double* rgradient, UnitCell& ucell))
133136
{
134137
// initialize arrays and parameters
135138
ModuleBase::GlobalFunc::ZEROS(rdirect, this->nx_); // very important
@@ -168,7 +171,7 @@ void Opt_TN::next_direct(double* px,
168171
// epsilon = 1e-9;
169172
for (int i = 0; i < this->nx_; ++i)
170173
temp_x[i] = px[i] + epsilon * cg_direct[i];
171-
(t->*p_calGradient)(temp_x, temp_gradient);
174+
(t->*p_calGradient)(temp_x, temp_gradient, ucell);
172175
for (int i = 0; i < this->nx_; ++i)
173176
temp_Hcgd[i] = (temp_gradient[i] - pgradient[i]) / epsilon;
174177

source/module_cell/unitcell.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,9 @@ void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) {
551551
this->atoms = new Atom[this->ntype]; // atom species.
552552
this->set_atom_flag = true;
553553

554+
this->symm.epsilon = PARAM.inp.symmetry_prec;
555+
this->symm.epsilon_input = PARAM.inp.symmetry_prec;
556+
554557
bool ok = true;
555558
bool ok2 = true;
556559

source/module_esolver/esolver.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,26 @@ class ESolver
2020
}
2121

2222
//! initialize the energy solver by using input parameters and cell modules
23-
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) = 0;
23+
virtual void before_all_runners(const Input_para& inp, UnitCell& ucell) = 0;
2424

2525
//! run energy solver
2626
virtual void runner(const int istep, UnitCell& cell) = 0;
2727

2828
//! perform post processing calculations
29-
virtual void after_all_runners(){};
29+
virtual void after_all_runners(UnitCell& ucell){};
3030

3131
//! deal with exx and other calculation than scf/md/relax/cell-relax:
3232
//! such as nscf, get_wf and get_pchg
33-
virtual void others(const int istep){};
33+
virtual void others(const int istep, UnitCell& ucell){};
3434

3535
//! calculate total energy of a given system
3636
virtual double cal_energy() = 0;
3737

3838
//! calcualte forces for the atoms in the given cell
39-
virtual void cal_force(ModuleBase::matrix& force) = 0;
39+
virtual void cal_force(ModuleBase::matrix& force, UnitCell& ucell) = 0;
4040

4141
//! calcualte stress of given cell
42-
virtual void cal_stress(ModuleBase::matrix& stress) = 0;
42+
virtual void cal_stress(ModuleBase::matrix& stress, UnitCell& ucell) = 0;
4343

4444
bool conv_esolver = true; // whether esolver is converged
4545

source/module_esolver/esolver_dp.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ double ESolver_DP::cal_energy()
127127
return dp_potential;
128128
}
129129

130-
void ESolver_DP::cal_force(ModuleBase::matrix& force)
130+
void ESolver_DP::cal_force(ModuleBase::matrix& force, UnitCell& ucell)
131131
{
132132
force = dp_force;
133133
ModuleIO::print_force(GlobalV::ofs_running, *ucell_, "TOTAL-FORCE (eV/Angstrom)", force, false);
134134
}
135135

136-
void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
136+
void ESolver_DP::cal_stress(ModuleBase::matrix& stress, UnitCell& ucell)
137137
{
138138
stress = dp_virial;
139139

@@ -148,7 +148,7 @@ void ESolver_DP::cal_stress(ModuleBase::matrix& stress)
148148
ModuleIO::print_stress("TOTAL-STRESS", stress, true, false);
149149
}
150150

151-
void ESolver_DP::after_all_runners()
151+
void ESolver_DP::after_all_runners(UnitCell& ucell)
152152
{
153153
GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl;
154154
GlobalV::ofs_running << std::setprecision(16);

source/module_esolver/esolver_dp.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class ESolver_DP : public ESolver
3636
* @param inp input parameters
3737
* @param cell unitcell information
3838
*/
39-
void before_all_runners(const Input_para& inp, UnitCell& cell) override;
39+
void before_all_runners(const Input_para& inp, UnitCell& ucell) override;
4040

4141
/**
4242
* @brief Run the DP solver for a given ion/md step and unit cell
@@ -59,21 +59,21 @@ class ESolver_DP : public ESolver
5959
*
6060
* @param force the computed atomic forces
6161
*/
62-
void cal_force(ModuleBase::matrix& force) override;
62+
void cal_force(ModuleBase::matrix& force, UnitCell& ucell) override;
6363

6464
/**
6565
* @brief get the computed lattice virials
6666
*
6767
* @param stress the computed lattice virials
6868
*/
69-
void cal_stress(ModuleBase::matrix& stress) override;
69+
void cal_stress(ModuleBase::matrix& stress, UnitCell& ucell) override;
7070

7171
/**
7272
* @brief Prints the final total energy of the DP model to the output file
7373
*
7474
* This function prints the final total energy of the DP model in eV to the output file along with some formatting.
7575
*/
76-
void after_all_runners() override;
76+
void after_all_runners(UnitCell& ucell) override;
7777

7878
private:
7979
/**

0 commit comments

Comments
 (0)