Skip to content

Commit 34375c8

Browse files
committed
Merge branch 'develop' into ucell23
2 parents c981303 + a2ec5d1 commit 34375c8

File tree

143 files changed

+1506
-3467
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

143 files changed

+1506
-3467
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
- [kpar](#kpar)
1212
- [bndpar](#bndpar)
1313
- [latname](#latname)
14-
- [psi\_initializer](#psi_initializer)
1514
- [init\_wfc](#init_wfc)
1615
- [init\_chg](#init_chg)
1716
- [init\_vel](#init_vel)
@@ -93,6 +92,7 @@
9392
- [scf\_os\_stop](#scf_os_stop)
9493
- [scf\_os\_thr](#scf_os_thr)
9594
- [scf\_os\_ndim](#scf_os_ndim)
95+
- [sc\_os\_ndim](#sc_os_ndim)
9696
- [chg\_extrap](#chg_extrap)
9797
- [lspinorb](#lspinorb)
9898
- [noncolin](#noncolin)
@@ -467,7 +467,7 @@
467467
- [abs\_broadening](#abs_broadening)
468468
- [ri\_hartree\_benchmark](#ri_hartree_benchmark)
469469
- [aims\_nbasis](#aims_nbasis)
470-
- [Reduced Density Matrix Functional Theory](#Reduced-Density-Matrix-Functional-Theory)
470+
- [Reduced Density Matrix Functional Theory](#reduced-density-matrix-functional-theory)
471471
- [rdmft](#rdmft)
472472
- [rdmft\_power\_alpha](#rdmft_power_alpha)
473473

@@ -580,17 +580,6 @@ These variables are used to control general system parameters.
580580
- triclinic: triclinic (14)
581581
- **Default**: none
582582

583-
### psi_initializer
584-
585-
- **Type**: Integer
586-
- **Description**: enable the experimental feature psi_initializer, to support use numerical atomic orbitals initialize wavefunction (`basis_type pw` case).
587-
588-
NOTE: this feature is not well-implemented for `nspin 4` case (closed presently), and cannot use with `calculation nscf`/`esolver_type sdft` cases.
589-
Available options are:
590-
- 0: disable psi_initializer
591-
- 1: enable psi_initializer
592-
- **Default**: 0
593-
594583
### init_wfc
595584

596585
- **Type**: String
@@ -602,8 +591,6 @@ These variables are used to control general system parameters.
602591
- atomic+random: add small random numbers on atomic pseudo-wavefunctions
603592
- file: from binary files `WAVEFUNC*.dat`, which are output by setting [out_wfc_pw](#out_wfc_pw) to `2`.
604593
- random: random numbers
605-
606-
with `psi_initializer 1`, two more options are supported:
607594
- nao: from numerical atomic orbitals. If they are not enough, other wave functions are initialized with random numbers.
608595
- nao+random: add small random numbers on numerical atomic orbitals
609596
- **Default**: atomic

source/Makefile.Objects

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ OBJS_PSI=psi.o\
400400

401401
OBJS_PSI_INITIALIZER=psi_initializer.o\
402402
psi_initializer_random.o\
403+
psi_initializer_file.o\
403404
psi_initializer_atomic.o\
404405
psi_initializer_atomic_random.o\
405406
psi_initializer_nao.o\
@@ -496,6 +497,7 @@ OBJS_IO=input_conv.o\
496497
to_wannier90_lcao.o\
497498
fR_overlap.o\
498499
unk_overlap_pw.o\
500+
write_pao.o\
499501
write_wfc_pw.o\
500502
winput.o\
501503
write_cube.o\
@@ -671,8 +673,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
671673
of_stress_pw.o\
672674
symmetry_rho.o\
673675
symmetry_rhog.o\
674-
wavefunc.o\
675-
wf_atomic.o\
676676
psi_init.o\
677677
elecond.o\
678678
sto_tool.o\

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ PW_Basis_K::~PW_Basis_K()
2222
delete[] igl2isz_k;
2323
delete[] igl2ig_k;
2424
delete[] gk2;
25-
delete[] ig2ixyz_k_;
2625
#if defined(__CUDA) || defined(__ROCM)
2726
if (this->device == "gpu") {
2827
if (this->precision == "single") {
@@ -169,6 +168,7 @@ void PW_Basis_K::setupIndGk()
169168
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks);
170169
}
171170
#endif
171+
this->get_ig2ixyz_k();
172172
return;
173173
}
174174

@@ -334,8 +334,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const
334334

335335
void PW_Basis_K::get_ig2ixyz_k()
336336
{
337-
delete[] this->ig2ixyz_k_;
338-
this->ig2ixyz_k_ = new int [this->npwk_max * this->nks];
337+
if (this->device != "gpu")
338+
{
339+
//only GPU need to get ig2ixyz_k
340+
return;
341+
}
342+
int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks];
339343
ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks);
340344
assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily.
341345
for(int ik = 0; ik < this->nks; ++ik)
@@ -348,15 +352,12 @@ void PW_Basis_K::get_ig2ixyz_k()
348352
int ixy = this->is2fftixy[is];
349353
int iy = ixy % this->ny;
350354
int ix = ixy / this->ny;
351-
ig2ixyz_k_[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
355+
ig2ixyz_k_cpu[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
352356
}
353357
}
354-
#if defined(__CUDA) || defined(__ROCM)
355-
if (this->device == "gpu") {
356-
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
357-
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks);
358-
}
359-
#endif
358+
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
359+
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
360+
delete[] ig2ixyz_k_cpu;
360361
}
361362

362363
std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ class PW_Basis_K : public PW_Basis
7171
const bool xprime_in = true
7272
);
7373

74-
void get_ig2ixyz_k();
75-
7674
public:
7775
int nks=0;//number of k points in this pool
7876
ModuleBase::Vector3<double> *kvec_d=nullptr; // Direct coordinates of k points
@@ -88,8 +86,7 @@ class PW_Basis_K : public PW_Basis
8886

8987
int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
9088
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
91-
int *ig2ixyz_k=nullptr;
92-
int *ig2ixyz_k_=nullptr;
89+
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz
9390

9491
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
9592

@@ -108,6 +105,8 @@ class PW_Basis_K : public PW_Basis
108105
double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
109106
//create igl2isz_k map array for fft
110107
void setupIndGk();
108+
// get ig2ixyz_k
109+
void get_ig2ixyz_k();
111110
//calculate G+K, it is a private function
112111
ModuleBase::Vector3<double> cal_GplusK_cartesian(const int ik, const int ig) const;
113112

source/module_basis/module_pw/test/test4-4.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,6 @@ TEST_F(PWTEST,test4_4)
213213
}
214214
}
215215

216-
//check getig2ixyz_k
217-
pwtest.get_ig2ixyz_k();
218-
for(int igl = 0; igl < npwk ; ++igl)
219-
{
220-
EXPECT_GE(pwtest.ig2ixyz_k_[igl + ik * pwtest.npwk_max], 0);
221-
}
222-
223216
}
224217
delete []tmp;
225218
delete [] rhor;

source/module_cell/read_atoms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
9090
}
9191
else if(PARAM.inp.basis_type == "pw")
9292
{
93-
if ((PARAM.inp.psi_initializer)&&(PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
93+
if ((PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
9494
{
9595
std::string orbital_file = PARAM.inp.orbital_dir + orbital_fn[it];
9696
this->read_orb_file(it, orbital_file, ofs_running, &(atoms[it]));

source/module_cell/unitcell.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,7 @@ void UnitCell::cal_nwfc(std::ofstream& log) {
537537
// Use localized basis
538538
//=====================
539539
if ((PARAM.inp.basis_type == "lcao") || (PARAM.inp.basis_type == "lcao_in_pw")
540-
|| ((PARAM.inp.basis_type == "pw") && (PARAM.inp.psi_initializer)
541-
&& (PARAM.inp.init_wfc.substr(0, 3) == "nao")
540+
|| ((PARAM.inp.basis_type == "pw") && (PARAM.inp.init_wfc.substr(0, 3) == "nao")
542541
&& (PARAM.inp.esolver_type == "ksdft"))) // xiaohui add 2013-09-02
543542
{
544543
ModuleBase::GlobalFunc::AUTO_SET("NBANDS", PARAM.inp.nbands);

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace ModuleESolver
5757
template <typename T>
5858
ESolver_KS_LIP<T>::~ESolver_KS_LIP()
5959
{
60+
delete this->psi_local;
6061
// delete Hamilt
6162
this->deallocate_hamilt();
6263
}
@@ -79,11 +80,22 @@ namespace ModuleESolver
7980
this->p_hamilt = nullptr;
8081
}
8182
}
83+
template <typename T>
84+
void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
85+
{
86+
ESolver_KS_PW<T>::before_scf(ucell, istep);
87+
this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
88+
}
8289

8390
template <typename T>
8491
void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
8592
{
8693
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
94+
delete this->psi_local;
95+
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
96+
this->p_psi_init->psi_initer->nbands_start(),
97+
this->psi->get_nbasis(),
98+
this->psi->get_ngk_pointer());
8799
#ifdef __EXX
88100
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
89101
|| PARAM.inp.calculation == "cell-relax"
@@ -94,14 +106,14 @@ namespace ModuleESolver
94106
this->exx_lip = std::unique_ptr<Exx_Lip<T>>(new Exx_Lip<T>(GlobalC::exx_info.info_lip,
95107
ucell.symm,
96108
&this->kv,
97-
this->p_wf_init,
109+
this->psi_local,
98110
this->kspw_psi,
99111
this->pw_wfc,
100112
this->pw_rho,
101113
this->sf,
102114
&ucell,
103115
this->pelec));
104-
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_wf_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
116+
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
105117
}
106118
}
107119
#endif
@@ -136,18 +148,8 @@ namespace ModuleESolver
136148
hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
137149
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
138150

139-
// It is not a good choice to overload another solve function here, this will spoil the concept of
140-
// multiple inheritance and polymorphism. But for now, we just do it in this way.
141-
// In the future, there will be a series of class ESolver_KS_LCAO_PW, HSolver_LCAO_PW and so on.
142-
std::weak_ptr<psi::Psi<T>> psig = this->p_wf_init->get_psig();
143-
144-
if (psig.expired())
145-
{
146-
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density_single", "psig lifetime is expired");
147-
}
148-
149151
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
150-
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);
152+
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
151153

152154
// add exx
153155
#ifdef __EXX

source/module_esolver/esolver_ks_lcaopw.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace ModuleESolver
2323
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
2424
void after_all_runners(UnitCell& ucell) override;
2525

26+
virtual void before_scf(UnitCell& ucell, const int istep) override;
27+
2628
protected:
2729
virtual void iter_init(UnitCell& ucell, const int istep, const int iter) override;
2830
virtual void iter_finish(UnitCell& ucell, const int istep, int& iter) override;
@@ -35,6 +37,8 @@ namespace ModuleESolver
3537

3638
virtual void allocate_hamilt(const UnitCell& ucell) override;
3739
virtual void deallocate_hamilt() override;
40+
41+
psi::Psi<T, base_device::DEVICE_CPU>* psi_local = nullptr; ///< psi for all local NAOs
3842

3943
#ifdef __EXX
4044
std::unique_ptr<Exx_Lip<T>> exx_lip;

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 20 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
118118
}
119119

120120
delete this->psi;
121-
delete this->p_wf_init;
121+
delete this->p_psi_init;
122122
}
123123

124124
template <typename T, typename Device>
@@ -189,26 +189,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
189189
&(this->pelec->f_en.vtxc));
190190
}
191191

192-
//! 7) prepare some parameters for electronic wave functions initilization
193-
this->p_wf_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
194-
PARAM.inp.ks_solver,
195-
PARAM.inp.basis_type,
196-
PARAM.inp.psi_initializer,
197-
this->pw_wfc);
198-
this->p_wf_init->prepare_init(&(this->sf),
199-
&ucell,
200-
1,
201-
#ifdef __MPI
202-
&GlobalC::Pkpoints,
203-
GlobalV::MY_RANK,
204-
#endif
205-
&this->ppcell);
206-
207-
if (this->psi != nullptr)
208-
{
209-
delete this->psi;
210-
this->psi = nullptr;
211-
}
212192

213193
//! initalize local pseudopotential
214194
this->locpp.init_vloc(ucell, this->pw_rhod);
@@ -219,17 +199,19 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
219199
this->ppcell.init_vnl(ucell, this->pw_rhod);
220200
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");
221201

222-
//! Allocate psi
223-
this->p_wf_init->allocate_psi(this->psi,
224-
this->kv.get_nkstot(),
225-
this->kv.get_nks(),
226-
this->kv.ngk.data(),
227-
this->pw_wfc->npwk_max,
228-
&this->sf,
229-
&this->ppcell,
230-
ucell);
231-
232-
assert(this->psi != nullptr);
202+
//! Allocate and initialize psi
203+
this->p_psi_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
204+
PARAM.inp.ks_solver,
205+
PARAM.inp.basis_type,
206+
GlobalV::MY_RANK,
207+
ucell,
208+
this->sf,
209+
GlobalC::Pkpoints,
210+
this->ppcell,
211+
*this->pw_wfc);
212+
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
213+
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
214+
233215
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
234216
? new psi::Psi<T, Device>(this->psi[0])
235217
: reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
@@ -267,7 +249,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
267249

268250
this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);
269251

270-
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell, ucell);
252+
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
271253
}
272254
if (ucell.ionic_position_updated)
273255
{
@@ -407,29 +389,11 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
407389
auto* dftu = ModuleDFTU::DFTU::get_instance();
408390
dftu->init(ucell, nullptr, this->kv.get_nks());
409391
}
410-
// after init_rho (in pelec->init_scf), we have rho now.
411-
// before hamilt2density, we update Hk and initialize psi
412-
413-
// before_scf function will be called everytime before scf. However, once
414-
// atomic coordinates changed, structure factor will change, therefore all
415-
// atomwise properties will change. So we need to reinitialize psi every
416-
// time before scf. But for random wavefunction, we dont, because random
417-
// wavefunction is not related to atomic coordinates. What the old strategy
418-
// does is only to initialize for once...
419-
if (((PARAM.inp.init_wfc == "random") && (istep == 0)) || (PARAM.inp.init_wfc != "random"))
420-
{
421-
this->p_wf_init->initialize_psi(this->psi,
422-
this->kspw_psi,
423-
this->p_hamilt,
424-
this->ppcell,
425-
ucell,
426-
GlobalV::ofs_running,
427-
this->already_initpsi);
428-
429-
if (this->already_initpsi == false)
430-
{
431-
this->already_initpsi = true;
432-
}
392+
393+
if (!this->already_initpsi)
394+
{
395+
this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
396+
this->already_initpsi = true;
433397
}
434398

435399
ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");

0 commit comments

Comments
 (0)