Skip to content

Commit 4922569

Browse files
committed
refactor hsolver
1 parent b890a5c commit 4922569

File tree

7 files changed

+17
-57
lines changed

7 files changed

+17
-57
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
381381
//---------------------------------------------------------------------------------------------------------------
382382

383383
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
384-
&this->wf,
385384
PARAM.inp.calculation,
386385
PARAM.inp.basis_type,
387386
PARAM.inp.ks_solver,
@@ -391,8 +390,7 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
391390
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
392391
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
393392
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
394-
hsolver::DiagoIterAssist<T, Device>::need_subspace,
395-
this->init_psi);
393+
hsolver::DiagoIterAssist<T, Device>::need_subspace);
396394

397395
hsolver_pw_obj.solve(this->p_hamilt,
398396
this->kspw_psi[0],

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
219219
// hsolver only exists in this function
220220
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(&this->kv,
221221
this->pw_wfc,
222-
&this->wf,
223222
this->stowf,
224223
this->stoche,
225224
this->p_hamilt_sto,
@@ -232,8 +231,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
232231
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
233232
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
234233
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
235-
hsolver::DiagoIterAssist<T, Device>::need_subspace,
236-
this->init_psi);
234+
hsolver::DiagoIterAssist<T, Device>::need_subspace);
237235

238236
hsolver_pw_sdft_obj.solve(this->p_hamilt,
239237
this->kspw_psi[0],

source/module_hsolver/hsolver_pw.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -261,27 +261,6 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
261261
ModuleBase::TITLE("HSolverPW", "solve");
262262
ModuleBase::timer::tick("HSolverPW", "solve");
263263

264-
//---------------------------------------------------------------------------------------------------------------
265-
//---------------------------------for psi init guess!!!!--------------------------------------------------------
266-
//---------------------------------------------------------------------------------------------------------------
267-
// if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
268-
// {
269-
// for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
270-
// {
271-
// /// update H(k) for each k point
272-
// pHamilt->updateHk(ik);
273-
274-
// /// update psi pointer for each k point
275-
// psi.fix_k(ik);
276-
277-
// /// for psi init guess!!!!
278-
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
279-
// }
280-
// }
281-
//---------------------------------------------------------------------------------------------------------------
282-
//---------------------------------------------------------------------------------------------------------------
283-
//---------------------------------------------------------------------------------------------------------------
284-
285264
this->rank_in_pool = rank_in_pool_in;
286265
this->nproc_in_pool = nproc_in_pool_in;
287266

source/module_hsolver/hsolver_pw.h

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class HSolverPW
2020

2121
public:
2222
HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
23-
wavefunc* pwf_in,
2423
const std::string calculation_type_in,
2524
const std::string basis_type_in,
2625
const std::string method_in,
@@ -30,13 +29,10 @@ class HSolverPW
3029
const int scf_iter_in,
3130
const int diag_iter_max_in,
3231
const double diag_thr_in,
33-
const bool need_subspace_in,
34-
const bool initialed_psi_in)
35-
: wfc_basis(wfc_basis_in), pwf(pwf_in),
36-
calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in),
37-
use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in),
38-
scf_iter(scf_iter_in), diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in),
39-
need_subspace(need_subspace_in), initialed_psi(initialed_psi_in) {};
32+
const bool need_subspace_in)
33+
: wfc_basis(wfc_basis_in), calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in),
34+
use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in), scf_iter(scf_iter_in),
35+
diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in), need_subspace(need_subspace_in){};
4036

4137
/// @brief solve function for pw
4238
/// @param pHamilt interface to hamilt
@@ -65,7 +61,6 @@ class HSolverPW
6561
void output_iterInfo();
6662

6763
ModulePW::PW_Basis_K* wfc_basis;
68-
wavefunc* pwf; // only for diago_PAO_in_pw_k2 func
6964

7065
const std::string calculation_type;
7166
const std::string basis_type;
@@ -74,24 +69,24 @@ class HSolverPW
7469
const bool use_uspp;
7570
const int nspin;
7671

77-
const int scf_iter; // Start from 1
72+
const int scf_iter; // Start from 1
7873
const int diag_iter_max; // max iter times for diagonalization
79-
const double diag_thr; // threshold for diagonalization
74+
const double diag_thr; // threshold for diagonalization
8075

8176
const bool need_subspace; // for cg or dav_subspace
82-
const bool initialed_psi;
8377

8478
protected:
8579
Device* ctx = {};
8680

8781
int rank_in_pool = 0;
8882
int nproc_in_pool = 1;
83+
8984
private:
9085
/// @brief calculate the threshold for iterative-diagonalization for each band
9186
void cal_ethr_band(const double& wk, const double* wg, const double& ethr, std::vector<double>& ethrs);
9287

9388
std::vector<double> ethr_band;
94-
89+
9590
#ifdef USE_PAW
9691
void paw_func_in_kloop(const int ik);
9792

source/module_hsolver/hsolver_pw_sdft.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
1010
{
1111
protected:
1212
using Real = typename GetTypeReal<T>::type;
13+
1314
public:
1415
HSolverPW_SDFT(K_Vectors* pkv,
1516
ModulePW::PW_Basis_K* wfc_basis_in,
16-
wavefunc* pwf_in,
1717
Stochastic_WF<T, Device>& stowf,
1818
StoChe<Real, Device>& stoche,
1919
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto,
@@ -26,10 +26,8 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
2626
const int scf_iter_in,
2727
const int diag_iter_max_in,
2828
const double diag_thr_in,
29-
const bool need_subspace_in,
30-
const bool initialed_psi_in)
29+
const bool need_subspace_in)
3130
: HSolverPW<T, Device>(wfc_basis_in,
32-
pwf_in,
3331
calculation_type_in,
3432
basis_type_in,
3533
method_in,
@@ -39,8 +37,7 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
3937
scf_iter_in,
4038
diag_iter_max_in,
4139
diag_thr_in,
42-
need_subspace_in,
43-
initialed_psi_in)
40+
need_subspace_in)
4441
{
4542
stoiter.init(pkv, wfc_basis_in, stowf, stoche, p_hamilt_sto);
4643
}
@@ -56,6 +53,7 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
5653
const bool skip_charge);
5754

5855
Stochastic_Iter<T, Device> stoiter;
56+
5957
protected:
6058
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
6159
using setmem_var_op = base_device::memory::set_memory_op<Real, Device>;

source/module_hsolver/test/test_hsolver_pw.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class TestHSolverPW : public ::testing::Test {
3737
ModulePW::PW_Basis_K pwbk;
3838
hsolver::HSolverPW<std::complex<float>, base_device::DEVICE_CPU> hs_f
3939
= hsolver::HSolverPW<std::complex<float>, base_device::DEVICE_CPU>(&pwbk,
40-
nullptr,
41-
4240
"scf",
4341
"pw",
4442
"cg",
@@ -48,12 +46,9 @@ class TestHSolverPW : public ::testing::Test {
4846
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
4947
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
5048
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,
51-
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
52-
false);
49+
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace);
5350
hsolver::HSolverPW<std::complex<double>, base_device::DEVICE_CPU> hs_d
5451
= hsolver::HSolverPW<std::complex<double>, base_device::DEVICE_CPU>(&pwbk,
55-
nullptr,
56-
5752
"scf",
5853
"pw",
5954
"cg",
@@ -63,8 +58,7 @@ class TestHSolverPW : public ::testing::Test {
6358
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
6459
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
6560
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,
66-
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
67-
false);
61+
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace);
6862

6963
hamilt::Hamilt<std::complex<double>> hamilt_test_d;
7064
hamilt::Hamilt<std::complex<float>> hamilt_test_f;

source/module_hsolver/test/test_hsolver_sdft.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ class TestHSolverPW_SDFT : public ::testing::Test
183183
= hsolver::HSolverPW_SDFT<std::complex<double>, base_device::DEVICE_CPU>(
184184
&kv,
185185
&pwbk,
186-
&wf,
187186
stowf,
188187
stoche,
189188
p_hamilt_sto,
@@ -196,8 +195,7 @@ class TestHSolverPW_SDFT : public ::testing::Test
196195
hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
197196
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
198197
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,
199-
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
200-
false);
198+
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace);
201199

202200
hamilt::Hamilt<std::complex<double>> hamilt_test_d;
203201

0 commit comments

Comments
 (0)