Skip to content

Commit 3c874a1

Browse files
committed
change sto_hchi to hamilt_sdft_pw
1 parent 66c4e58 commit 3c874a1

File tree

18 files changed

+227
-350
lines changed

18 files changed

+227
-350
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ OBJS_GINT=gint.o\
290290
init_orb.o\
291291

292292
OBJS_HAMILT=hamilt_pw.o\
293+
hamilt_sdft_pw.o\
293294
operator.o\
294295
operator_pw.o\
295296
ekinetic_pw.o\
@@ -651,7 +652,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
651652
structure_factor_k.o\
652653
soc.o\
653654
sto_iter.o\
654-
sto_hchi.o\
655655
sto_che.o\
656656
sto_wf.o\
657657
sto_func.o\

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
133133
void ESolver_SDFT_PW::before_scf(const int istep)
134134
{
135135
ESolver_KS_PW::before_scf(istep);
136+
delete reinterpret_cast<hamilt::HamiltPW<double>*>(this->p_hamilt);
137+
this->p_hamilt = new hamilt::HamiltSdftPW<std::complex<double>>(this->pelec->pot,
138+
this->pw_wfc,
139+
&this->kv,
140+
PARAM.globalv.npol,
141+
&this->stoche.emin_sto,
142+
&this->stoche.emax_sto);
143+
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>>*>(this->p_hamilt);
144+
136145
if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0)
137146
{
138147
Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto);
@@ -177,7 +186,8 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
177186
this->pw_wfc,
178187
&this->wf,
179188
this->stowf,
180-
this->stoche,
189+
this->stoche,
190+
this->p_hamilt_sto,
181191
PARAM.inp.calculation,
182192
PARAM.inp.basis_type,
183193
PARAM.inp.ks_solver,

source/module_esolver/esolver_sdft_pw.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
#define ESOLVER_SDFT_PW_H
33

44
#include "esolver_ks_pw.h"
5-
#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h"
65
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
76
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
87
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
8+
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"
99

1010
namespace ModuleESolver
1111
{
@@ -27,6 +27,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>>
2727
public:
2828
Stochastic_WF stowf;
2929
StoChe<double> stoche;
30+
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr;
3031

3132
protected:
3233
virtual void before_scf(const int istep) override;

source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
list(APPEND hamilt_stodft_srcs
2+
hamilt_sdft_pw.cpp
23
sto_iter.cpp
3-
sto_hchi.cpp
44
sto_che.cpp
55
sto_wf.cpp
66
sto_func.cpp
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "hamilt_sdft_pw.h"
2+
#include "module_base/timer.h"
3+
4+
namespace hamilt
5+
{
6+
7+
template <typename T, typename Device>
8+
HamiltSdftPW<T, Device>::HamiltSdftPW(elecstate::Potential* pot_in,
9+
ModulePW::PW_Basis_K* wfc_basis,
10+
K_Vectors* p_kv,
11+
const int& npol,
12+
double* emin_in,
13+
double* emax_in)
14+
: HamiltPW<T, Device>(pot_in, wfc_basis, p_kv), ngk(p_kv->ngk)
15+
{
16+
this->classname = "HamiltSdftPW";
17+
this->npwk_max = wfc_basis->npwk_max;
18+
this->npol = npol;
19+
this->emin = emin_in;
20+
this->emax = emax_in;
21+
}
22+
23+
template <typename T, typename Device>
24+
void HamiltSdftPW<T, Device>::hPsi(const T* psi_in, T* hpsi, const int& nbands)
25+
{
26+
auto call_act = [&, this](const Operator<T, Device>* op) -> void {
27+
op->act(nbands, this->npwk_max, this->npol, psi_in, hpsi, this->ngk[op->get_ik()]);
28+
};
29+
30+
ModuleBase::timer::tick("HamiltSdftPW", "hPsi");
31+
ModuleBase::GlobalFunc::ZEROS(hpsi, nbands * this->npwk_max * this->npol);
32+
call_act(this->ops);
33+
Operator<T, Device>* node((Operator<T, Device>*)this->ops->next_op);
34+
while (node != nullptr)
35+
{
36+
call_act(node);
37+
node = (Operator<T, Device>*)(node->next_op);
38+
}
39+
ModuleBase::timer::tick("HamiltSdftPW", "hPsi");
40+
41+
return;
42+
}
43+
44+
template <typename T, typename Device>
45+
void HamiltSdftPW<T, Device>::hPsi_norm(const T* psi_in, T* hpsi_norm, const int& nbands)
46+
{
47+
ModuleBase::timer::tick("HamiltSdftPW", "hPsi_norm");
48+
49+
this->hPsi(psi_in, hpsi_norm, nbands);
50+
51+
const int ik = this->ops->get_ik();
52+
const int npwk_max = this->npwk_max;
53+
const int npwk = this->ngk[ik];
54+
using Real = typename GetTypeReal<T>::type;
55+
const Real emin = *this->emin;
56+
const Real emax = *this->emax;
57+
const Real Ebar = (emin + emax) / 2;
58+
const Real DeltaE = (emax - emin) / 2;
59+
for (int ib = 0; ib < nbands; ++ib)
60+
{
61+
for (int ig = 0; ig < npwk; ++ig)
62+
{
63+
hpsi_norm[ib * npwk_max + ig]
64+
= (hpsi_norm[ib * npwk_max + ig] - Ebar * psi_in[ib * npwk_max + ig]) / DeltaE;
65+
}
66+
}
67+
ModuleBase::timer::tick("HamiltSdftPW", "hPsi_norm");
68+
}
69+
70+
template class HamiltSdftPW<std::complex<float>, base_device::DEVICE_CPU>;
71+
template class HamiltSdftPW<std::complex<double>, base_device::DEVICE_CPU>;
72+
#if ((defined __CUDA) || (defined __ROCM))
73+
template class HamiltSdftPW<std::complex<float>, base_device::DEVICE_GPU>;
74+
template class HamiltSdftPW<std::complex<double>, base_device::DEVICE_GPU>;
75+
#endif
76+
77+
} // namespace hamilt
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#ifndef HAMILTSDFTPW_H
2+
#define HAMILTSDFTPW_H
3+
4+
#include "module_hamilt_pw/hamilt_pwdft/hamilt_pw.h"
5+
6+
namespace hamilt
7+
{
8+
9+
template <typename T, typename Device = base_device::DEVICE_CPU>
10+
class HamiltSdftPW : public HamiltPW<T, Device>
11+
{
12+
public:
13+
/**
14+
* @brief Construct a new Hamilt Sdft P W object
15+
*
16+
* @param pot_in potential
17+
* @param wfc_basis pw basis for wave functions
18+
* @param p_kv k vectors
19+
* @param npol the length of wave function is npol * npwk_max
20+
* @param emin_in Emin of the Hamiltonian
21+
* @param emax_in Emax of the Hamiltonian
22+
*/
23+
HamiltSdftPW(elecstate::Potential* pot_in,
24+
ModulePW::PW_Basis_K* wfc_basis,
25+
K_Vectors* p_kv,
26+
const int& npol,
27+
double* emin_in,
28+
double* emax_in);
29+
/**
30+
* @brief Destroy the Hamilt Sdft P W object
31+
*
32+
*/
33+
~HamiltSdftPW(){};
34+
35+
// void update_emin_emax(const double& emin, const double& emax)
36+
// {
37+
// this->emin = &emin;
38+
// this->emax = &emax;
39+
// }
40+
41+
/**
42+
* @brief Calculate \hat{H}|psi>
43+
*
44+
* @param psi_in input wave function
45+
* @param hpsi output wave function
46+
* @param nbands number of bands
47+
*/
48+
void hPsi(const T* psi_in, T* hpsi, const int& nbands = 1);
49+
50+
/**
51+
* @brief Calculate \hat{H}|psi> and normalize it
52+
*
53+
* @param psi_in input wave function
54+
* @param hpsi output wave function
55+
* @param nbands number of bands
56+
*/
57+
void hPsi_norm(const T* psi_in, T* hpsi, const int& nbands = 1);
58+
59+
double* emin = nullptr; ///< Emin of the Hamiltonian
60+
double* emax = nullptr; ///< Emax of the Hamiltonian
61+
62+
private:
63+
int npwk_max = 0; ///< maximum number of plane waves
64+
int npol = 0; ///< number of polarizations
65+
std::vector<int>& ngk; ///< number of G vectors
66+
};
67+
68+
} // namespace hamilt
69+
70+
#endif

source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in,
2121
this->p_elec = p_elec_in;
2222
this->p_psi = p_psi_in;
2323
this->p_hamilt = p_hamilt_in;
24+
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>>*>(p_hamilt_in);
2425
this->p_stowf = p_stowf_in;
2526
this->nbands_ks = p_psi_in->get_nbands();
2627
this->nbands_sto = p_stowf_in->nchi;
2728
this->method_sto = stoche.method_sto;
28-
this->stohchi.init(p_wfcpw_in, p_kv_in, &stoche.emin_sto, &stoche.emax_sto);
2929
this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto);
3030
}
3131
void Sto_DOS::decide_param(const int& dos_nche,
@@ -44,23 +44,22 @@ void Sto_DOS::decide_param(const int& dos_nche,
4444
this->nbands_sto,
4545
this->p_kv,
4646
this->p_stowf,
47-
this->p_hamilt,
48-
this->stohchi);
47+
this->p_hamilt_sto);
4948
if (dos_setemax)
5049
{
5150
this->emax = dos_emax_ev;
5251
}
5352
else
5453
{
55-
this->emax = *stohchi.Emax * ModuleBase::Ry_to_eV;
54+
this->emax = *p_hamilt_sto->emax * ModuleBase::Ry_to_eV;
5655
}
5756
if (dos_setemin)
5857
{
5958
this->emin = dos_emin_ev;
6059
}
6160
else
6261
{
63-
this->emin = *stohchi.Emin * ModuleBase::Ry_to_eV;
62+
this->emin = *p_hamilt_sto->emin * ModuleBase::Ry_to_eV;
6463
}
6564

6665
if (!dos_setemax && !dos_setemin)
@@ -103,7 +102,6 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
103102
{
104103
this->p_hamilt->updateHk(ik);
105104
}
106-
stohchi.current_ik = ik;
107105
const int npw = p_kv->ngk[ik];
108106
const int nchipk = this->p_stowf->nchip[ik];
109107

@@ -118,11 +116,11 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
118116
p_stowf->chi0->fix_k(ik);
119117
pchi = p_stowf->chi0->get_pointer();
120118
}
121-
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
122-
&stohchi,
123-
std::placeholders::_1,
124-
std::placeholders::_2,
125-
std::placeholders::_3);
119+
auto hchi_norm = std::bind(&hamilt::HamiltSdftPW<std::complex<double>>::hPsi_norm,
120+
p_hamilt_sto,
121+
std::placeholders::_1,
122+
std::placeholders::_2,
123+
std::placeholders::_3);
126124
if (this->method_sto == 1)
127125
{
128126
che.tracepolyA(hchi_norm, pchi, npw, npwx, nchipk);

source/module_hamilt_pw/hamilt_stodft/sto_dos.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#ifndef STO_DOS
22
#define STO_DOS
33
#include "module_elecstate/elecstate.h"
4-
#include "module_hamilt_general/hamilt.h"
4+
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"
55
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
66
#include "module_hamilt_pw/hamilt_stodft/sto_func.h"
7-
#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h"
87
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
98

109
class Sto_DOS
@@ -61,8 +60,9 @@ class Sto_DOS
6160
psi::Psi<std::complex<double>>* p_psi = nullptr; ///< pointer to the wavefunction
6261
hamilt::Hamilt<std::complex<double>>* p_hamilt; ///< pointer to the Hamiltonian
6362
Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions
64-
Stochastic_hchi stohchi; ///< stochastic hchi
6563
Sto_Func<double> stofunc; ///< functions
64+
65+
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT
6666
};
6767

6868
#endif // STO_DOS

source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ Sto_EleCond::Sto_EleCond(UnitCell* p_ucell_in,
2525
: EleCond(p_ucell_in, p_kv_in, p_elec_in, p_wfcpw_in, p_psi_in, p_ppcell_in)
2626
{
2727
this->p_hamilt = p_hamilt_in;
28+
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>>*>(p_hamilt_in);
2829
this->p_stowf = p_stowf_in;
2930
this->nbands_ks = p_psi_in->get_nbands();
3031
this->nbands_sto = p_stowf_in->nchi;
31-
this->stohchi.init(p_wfcpw_in, p_kv_in, &stoche.emin_sto, &stoche.emax_sto);
3232
this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto);
3333
}
3434

@@ -91,24 +91,23 @@ void Sto_EleCond::decide_nche(const double dt,
9191

9292
int nche_new = 0;
9393
loop:
94-
// re-set Emin & Emax both in stohchi & stofunc
94+
// re-set Emin & Emax both in p_hamilt_sto & stofunc
9595
check_che(std::max(nche_old * 2, fd_nche),
9696
try_emin,
9797
try_emax,
9898
this->nbands_sto,
9999
this->p_kv,
100100
this->p_stowf,
101-
this->p_hamilt,
102-
this->stohchi);
101+
this->p_hamilt_sto);
103102

104103
// second try to find nche with new Emin & Emax
105104
getnche(nche_new);
106105

107106
if (nche_new > nche_old * 2)
108107
{
109108
nche_old = nche_new;
110-
try_emin = *stohchi.Emin;
111-
try_emax = *stohchi.Emax;
109+
try_emin = *p_hamilt_sto->emin;
110+
try_emax = *p_hamilt_sto->emax;
112111
goto loop;
113112
}
114113

@@ -177,8 +176,8 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi<std::complex<float>>& kspsi_all,
177176
psi::Psi<std::complex<float>> f_rightchi(1, perbands_sto, npwx, p_kv->ngk.data());
178177
psi::Psi<std::complex<float>> f_right_hchi(1, perbands_sto, npwx, p_kv->ngk.data());
179178

180-
this->stohchi.hchi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto);
181-
this->stohchi.hchi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto);
179+
this->p_hamilt_sto->hPsi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto);
180+
this->p_hamilt_sto->hPsi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto);
182181
convert_psi(rightchi, f_rightchi);
183182
convert_psi(right_hchi, f_right_hchi);
184183
right_hchi.resize(1, 1, 1);
@@ -589,7 +588,6 @@ void Sto_EleCond::sKG(const int& smear_type,
589588
{
590589
this->p_hamilt->updateHk(ik);
591590
}
592-
this->stohchi.current_ik = ik;
593591
const int npw = p_kv->ngk[ik];
594592

595593
// get allbands_ks
@@ -733,8 +731,8 @@ void Sto_EleCond::sKG(const int& smear_type,
733731

734732
auto nroot_fd = std::bind(&Sto_Func<double>::nroot_fd, &this->stofunc, std::placeholders::_1);
735733
che.calcoef_real(nroot_fd);
736-
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
737-
&stohchi,
734+
auto hchi_norm = std::bind(&hamilt::HamiltSdftPW<std::complex<double>>::hPsi_norm,
735+
p_hamilt_sto,
738736
std::placeholders::_1,
739737
std::placeholders::_2,
740738
std::placeholders::_3);

source/module_hamilt_pw/hamilt_stodft/sto_elecond.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@ class Sto_EleCond : protected EleCond
6060
int cond_dtbatch = 0; ///< number of time steps in a batch
6161
hamilt::Hamilt<std::complex<double>>* p_hamilt; ///< pointer to the Hamiltonian
6262
Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions
63-
Stochastic_hchi stohchi; ///< stochastic hchi
6463
Sto_Func<double> stofunc; ///< functions
6564

65+
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT
66+
6667
protected:
6768
/**
6869
* @brief calculate Jmatrix <leftv|J|rightv>

0 commit comments

Comments
 (0)