
Commit 52457bf

zgn-26714, Qianruipku, dyzheng authored
Cherry-picked #5775 to resolve the single-precision version error. (#6418)
* Refactor: Use psi_initializer instead of wavefunc (#5775)
* use psi_initializer
* fix compile
* same results of random init
* make atomic initialized results right
* finish refactor
* fix compile
* fix compile
* fix UTs
* update results
* update results
* update GPU results
* update
* refactor pw
* change 108_PW_RE_PINT_RKS results
* update results
* remove openmp for random generate
* update
* remove psi_initializer in Doc
* remove omp2
* fix compile
* re-push an unadded modification
* Fixed the bug of duplicate definitions in the source/module_hsolver/test/test_hsolver_sdft.cpp file

Cherry-picking #5775 resolves the single-precision version errors. Tests were run locally on the versions before and after the cherry-pick, and the failing cases were almost the same in both, so the failures were concluded not to be caused by the cherry-pick.

---------

Co-authored-by: Qianrui Liu <[email protected]>
Co-authored-by: dyzheng <[email protected]>
1 parent 46b0da1 commit 52457bf

142 files changed: +1511 -3470 lines changed


docs/advanced/input_files/input-main.md

Lines changed: 2 additions & 15 deletions
@@ -11,7 +11,6 @@
  - [kpar](#kpar)
  - [bndpar](#bndpar)
  - [latname](#latname)
- - [psi\_initializer](#psi_initializer)
  - [init\_wfc](#init_wfc)
  - [init\_chg](#init_chg)
  - [init\_vel](#init_vel)
@@ -93,6 +92,7 @@
  - [scf\_os\_stop](#scf_os_stop)
  - [scf\_os\_thr](#scf_os_thr)
  - [scf\_os\_ndim](#scf_os_ndim)
+ - [sc\_os\_ndim](#sc_os_ndim)
  - [chg\_extrap](#chg_extrap)
  - [lspinorb](#lspinorb)
  - [noncolin](#noncolin)
@@ -437,7 +437,7 @@
  - [abs\_broadening](#abs_broadening)
  - [ri\_hartree\_benchmark](#ri_hartree_benchmark)
  - [aims\_nbasis](#aims_nbasis)
- - [Reduced Density Matrix Functional Theory](#Reduced-Density-Matrix-Functional-Theory)
+ - [Reduced Density Matrix Functional Theory](#reduced-density-matrix-functional-theory)
  - [rdmft](#rdmft)
  - [rdmft\_power\_alpha](#rdmft_power_alpha)

@@ -550,17 +550,6 @@ These variables are used to control general system parameters.
  - triclinic: triclinic (14)
  - **Default**: none

- ### psi_initializer
-
- - **Type**: Integer
- - **Description**: enable the experimental feature psi_initializer, to support use numerical atomic orbitals initialize wavefunction (`basis_type pw` case).
-
- NOTE: this feature is not well-implemented for `nspin 4` case (closed presently), and cannot use with `calculation nscf`/`esolver_type sdft` cases.
- Available options are:
- - 0: disable psi_initializer
- - 1: enable psi_initializer
- - **Default**: 0
-
  ### init_wfc

  - **Type**: String
@@ -572,8 +561,6 @@ These variables are used to control general system parameters.
  - atomic+random: add small random numbers on atomic pseudo-wavefunctions
  - file: from binary files `WAVEFUNC*.dat`, which are output by setting [out_wfc_pw](#out_wfc_pw) to `2`.
  - random: random numbers
-
- with `psi_initializer 1`, two more options are supported:
  - nao: from numerical atomic orbitals. If they are not enough, other wave functions are initialized with random numbers.
  - nao+random: add small random numbers on numerical atomic orbitals

source/Makefile.Objects

Lines changed: 2 additions & 2 deletions
@@ -397,6 +397,7 @@ OBJS_PSI=psi.o\

  OBJS_PSI_INITIALIZER=psi_initializer.o\
      psi_initializer_random.o\
+     psi_initializer_file.o\
      psi_initializer_atomic.o\
      psi_initializer_atomic_random.o\
      psi_initializer_nao.o\
@@ -493,6 +494,7 @@ OBJS_IO=input_conv.o\
      to_wannier90_lcao.o\
      fR_overlap.o\
      unk_overlap_pw.o\
+     write_pao.o\
      write_wfc_pw.o\
      winput.o\
      write_cube.o\
@@ -668,8 +670,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
      of_stress_pw.o\
      symmetry_rho.o\
      symmetry_rhog.o\
-     wavefunc.o\
-     wf_atomic.o\
      psi_init.o\
      elecond.o\
      sto_tool.o\

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 11 additions & 10 deletions
@@ -22,7 +22,6 @@ PW_Basis_K::~PW_Basis_K()
      delete[] igl2isz_k;
      delete[] igl2ig_k;
      delete[] gk2;
-     delete[] ig2ixyz_k_;
  #if defined(__CUDA) || defined(__ROCM)
      if (this->device == "gpu") {
          if (this->precision == "single") {
@@ -169,6 +168,7 @@ void PW_Basis_K::setupIndGk()
          syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks);
      }
  #endif
+     this->get_ig2ixyz_k();
      return;
  }

@@ -334,8 +334,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const

  void PW_Basis_K::get_ig2ixyz_k()
  {
-     delete[] this->ig2ixyz_k_;
-     this->ig2ixyz_k_ = new int [this->npwk_max * this->nks];
+     if (this->device != "gpu")
+     {
+         //only GPU need to get ig2ixyz_k
+         return;
+     }
+     int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks];
      ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks);
      assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily.
      for(int ik = 0; ik < this->nks; ++ik)
@@ -348,15 +352,12 @@ void PW_Basis_K::get_ig2ixyz_k()
              int ixy = this->is2fftixy[is];
              int iy = ixy % this->ny;
              int ix = ixy / this->ny;
-             ig2ixyz_k_[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
+             ig2ixyz_k_cpu[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
          }
      }
- #if defined(__CUDA) || defined(__ROCM)
-     if (this->device == "gpu") {
-         resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
-         syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks);
-     }
- #endif
+     resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
+     syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
+     delete[] ig2ixyz_k_cpu;
  }

  std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const
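
The index arithmetic in `get_ig2ixyz_k()` can be illustrated in isolation. The following is a minimal sketch, not ABACUS code: `flatten_ixyz` and `build_ig2ixyz` are hypothetical names, the z index is passed in directly because the hunk does not show how `iz` is derived, and a `std::vector` stands in for the temporary host buffer (`ig2ixyz_k_cpu`) that the new code fills before copying to the device.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of ABACUS: flatten (ix, iy, iz) with iz
// varying fastest, matching `iz + iy * nz + ix * ny * nz` in the hunk.
inline int flatten_ixyz(const int ix, const int iy, const int iz, const int ny, const int nz)
{
    assert(ix >= 0 && iy >= 0 && iy < ny && iz >= 0 && iz < nz);
    return iz + iy * nz + ix * ny * nz;
}

// Hypothetical stand-in for the host staging step. Each plane wave is given
// by a combined xy index (ixy = ix * ny + iy) and a z index; the result is
// the flattened grid index for every plane wave.
inline std::vector<int> build_ig2ixyz(const std::vector<int>& ixy,
                                      const std::vector<int>& iz,
                                      const int ny,
                                      const int nz)
{
    assert(ixy.size() == iz.size());
    std::vector<int> map(ixy.size());
    for (std::size_t ig = 0; ig < ixy.size(); ++ig)
    {
        const int iy = ixy[ig] % ny; // same split as in the hunk
        const int ix = ixy[ig] / ny;
        map[ig] = flatten_ixyz(ix, iy, iz[ig], ny, nz);
    }
    return map; // the vector frees itself, so no explicit delete[] is needed
}
```

With `ny = 4` and `nz = 5`, the combined index `ixy = 6` splits into `ix = 1`, `iy = 2`, so together with `iz = 3` the flattened index is `3 + 2*5 + 1*4*5 = 33`.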

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 3 additions & 4 deletions
@@ -71,8 +71,6 @@ class PW_Basis_K : public PW_Basis
          const bool xprime_in = true
      );

-     void get_ig2ixyz_k();
-
  public:
      int nks=0;//number of k points in this pool
      ModuleBase::Vector3<double> *kvec_d=nullptr; // Direct coordinates of k points
@@ -88,8 +86,7 @@ class PW_Basis_K : public PW_Basis

      int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
      int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
-     int *ig2ixyz_k=nullptr;
-     int *ig2ixyz_k_=nullptr;
+     int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz

      double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]

@@ -108,6 +105,8 @@ class PW_Basis_K : public PW_Basis
      double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
      //create igl2isz_k map array for fft
      void setupIndGk();
+     // get ig2ixyz_k
+     void get_ig2ixyz_k();
      //calculate G+K, it is a private function
      ModuleBase::Vector3<double> cal_GplusK_cartesian(const int ik, const int ig) const;

source/module_basis/module_pw/test/test4-4.cpp

Lines changed: 0 additions & 7 deletions
@@ -213,13 +213,6 @@ TEST_F(PWTEST,test4_4)
              }
          }

-         //check getig2ixyz_k
-         pwtest.get_ig2ixyz_k();
-         for(int igl = 0; igl < npwk ; ++igl)
-         {
-             EXPECT_GE(pwtest.ig2ixyz_k_[igl + ik * pwtest.npwk_max], 0);
-         }
-
      }
      delete []tmp;
      delete [] rhor;

source/module_cell/read_atoms.cpp

Lines changed: 1 addition & 2 deletions
@@ -99,7 +99,6 @@ int UnitCell::read_atom_species(std::ifstream &ifa, std::ofstream &ofs_running)
          ||(PARAM.inp.basis_type == "lcao_in_pw")
          ||(
              (PARAM.inp.basis_type == "pw")
-             &&(PARAM.inp.psi_initializer)
              &&(PARAM.inp.init_wfc.substr(0, 3) == "nao")
          )
          || PARAM.inp.onsite_radius > 0.0
@@ -455,7 +454,7 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
          }
          else if(PARAM.inp.basis_type == "pw")
          {
-             if ((PARAM.inp.psi_initializer)&&(PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
+             if ((PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
              {
                  std::string orbital_file = PARAM.inp.orbital_dir + orbital_fn[it];
                  this->read_orb_file(it, orbital_file, ofs_running, &(atoms[it]));
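
With the `psi_initializer` switch removed, the `init_wfc` prefix alone decides whether orbital files are read in the `pw` branch. A minimal sketch of that prefix test follows; `is_nao_init` and `pw_needs_orbital_file` are hypothetical names introduced for illustration, and the second function mirrors only the condition visible in the `read_atom_positions` hunk.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: true when init_wfc starts with "nao", which covers
// both documented options `nao` and `nao+random`.
inline bool is_nao_init(const std::string& init_wfc)
{
    // Equivalent to init_wfc.substr(0, 3) == "nao" without a temporary string.
    return init_wfc.compare(0, 3, "nao") == 0;
}

// Hypothetical mirror of the new condition in the `pw` branch: orbital files
// are read for NAO-based initialization or when onsite_radius is positive.
inline bool pw_needs_orbital_file(const std::string& init_wfc, const double onsite_radius)
{
    return is_nao_init(init_wfc) || onsite_radius > 0.0;
}

// Small self-check documenting the expected behaviour.
inline void check_examples()
{
    assert(is_nao_init("nao+random"));
    assert(!is_nao_init("atomic"));
}
```

Before this commit the same check was additionally gated on `PARAM.inp.psi_initializer`, so `init_wfc nao` had no effect here unless that flag was set.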

source/module_cell/unitcell.cpp

Lines changed: 1 addition & 2 deletions
@@ -693,8 +693,7 @@ void UnitCell::cal_nwfc(std::ofstream& log) {
      // Use localized basis
      //=====================
      if ((PARAM.inp.basis_type == "lcao") || (PARAM.inp.basis_type == "lcao_in_pw")
-         || ((PARAM.inp.basis_type == "pw") && (PARAM.inp.psi_initializer)
-             && (PARAM.inp.init_wfc.substr(0, 3) == "nao")
+         || ((PARAM.inp.basis_type == "pw") && (PARAM.inp.init_wfc.substr(0, 3) == "nao")
              && (PARAM.inp.esolver_type == "ksdft"))) // xiaohui add 2013-09-02
      {
          ModuleBase::GlobalFunc::AUTO_SET("NBANDS", PARAM.inp.nbands);

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 15 additions & 13 deletions
@@ -57,6 +57,7 @@ namespace ModuleESolver
      template <typename T>
      ESolver_KS_LIP<T>::~ESolver_KS_LIP()
      {
+         delete this->psi_local;
          // delete Hamilt
          this->deallocate_hamilt();
      }
@@ -79,11 +80,22 @@ namespace ModuleESolver
              this->p_hamilt = nullptr;
          }
      }
+     template <typename T>
+     void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
+     {
+         ESolver_KS_PW<T>::before_scf(ucell, istep);
+         this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
+     }

      template <typename T>
      void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
      {
          ESolver_KS_PW<T>::before_all_runners(ucell, inp);
+         delete this->psi_local;
+         this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
+                                           this->p_psi_init->psi_initer->nbands_start(),
+                                           this->psi->get_nbasis(),
+                                           this->psi->get_ngk_pointer());
  #ifdef __EXX
          if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
              || PARAM.inp.calculation == "cell-relax"
@@ -94,14 +106,14 @@ namespace ModuleESolver
                  this->exx_lip = std::unique_ptr<Exx_Lip<T>>(new Exx_Lip<T>(GlobalC::exx_info.info_lip,
                                                                             ucell.symm,
                                                                             &this->kv,
-                                                                            this->p_wf_init,
+                                                                            this->psi_local,
                                                                             this->kspw_psi,
                                                                             this->pw_wfc,
                                                                             this->pw_rho,
                                                                             this->sf,
                                                                             &ucell,
                                                                             this->pelec));
-                 // this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_wf_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
+                 // this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
              }
          }
  #endif
@@ -136,18 +148,8 @@ namespace ModuleESolver
          hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
          bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

-         // It is not a good choice to overload another solve function here, this will spoil the concept of
-         // multiple inheritance and polymorphism. But for now, we just do it in this way.
-         // In the future, there will be a series of class ESolver_KS_LCAO_PW, HSolver_LCAO_PW and so on.
-         std::weak_ptr<psi::Psi<T>> psig = this->p_wf_init->get_psig();
-
-         if (psig.expired())
-         {
-             ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density_single", "psig lifetime is expired");
-         }
-
          hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
-         hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);
+         hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);

          // add exx
  #ifdef __EXX

source/module_esolver/esolver_ks_lcaopw.h

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,8 @@ namespace ModuleESolver
          void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
          void after_all_runners(UnitCell& ucell) override;

+         virtual void before_scf(UnitCell& ucell, const int istep) override;
+
      protected:
          virtual void iter_init(UnitCell& ucell, const int istep, const int iter) override;
          virtual void iter_finish(UnitCell& ucell, const int istep, int& iter) override;
@@ -35,6 +37,8 @@ namespace ModuleESolver

          virtual void allocate_hamilt(const UnitCell& ucell) override;
          virtual void deallocate_hamilt() override;
+
+         psi::Psi<T, base_device::DEVICE_CPU>* psi_local = nullptr; ///< psi for all local NAOs

  #ifdef __EXX
          std::unique_ptr<Exx_Lip<T>> exx_lip;

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 20 additions & 56 deletions
@@ -115,7 +115,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
      }

      delete this->psi;
-     delete this->p_wf_init;
+     delete this->p_psi_init;
  }

  template <typename T, typename Device>
@@ -186,26 +186,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
                                         &(this->pelec->f_en.vtxc));
      }

-     //! 7) prepare some parameters for electronic wave functions initilization
-     this->p_wf_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
-                                                   PARAM.inp.ks_solver,
-                                                   PARAM.inp.basis_type,
-                                                   PARAM.inp.psi_initializer,
-                                                   this->pw_wfc);
-     this->p_wf_init->prepare_init(&(this->sf),
-                                   &ucell,
-                                   1,
- #ifdef __MPI
-                                   &GlobalC::Pkpoints,
-                                   GlobalV::MY_RANK,
- #endif
-                                   &this->ppcell);
-
-     if (this->psi != nullptr)
-     {
-         delete this->psi;
-         this->psi = nullptr;
-     }

      //! initalize local pseudopotential
      this->locpp.init_vloc(ucell, this->pw_rhod);
@@ -216,17 +196,19 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
      this->ppcell.init_vnl(ucell, this->pw_rhod);
      ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");

-     //! Allocate psi
-     this->p_wf_init->allocate_psi(this->psi,
-                                   this->kv.get_nkstot(),
-                                   this->kv.get_nks(),
-                                   this->kv.ngk.data(),
-                                   this->pw_wfc->npwk_max,
-                                   &this->sf,
-                                   &this->ppcell,
-                                   ucell);
-
-     assert(this->psi != nullptr);
+     //! Allocate and initialize psi
+     this->p_psi_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
+                                                    PARAM.inp.ks_solver,
+                                                    PARAM.inp.basis_type,
+                                                    GlobalV::MY_RANK,
+                                                    ucell,
+                                                    this->sf,
+                                                    GlobalC::Pkpoints,
+                                                    this->ppcell,
+                                                    *this->pw_wfc);
+     allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
+     this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
+
      this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
                           ? new psi::Psi<T, Device>(this->psi[0])
                           : reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
@@ -264,7 +246,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)

          this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);

-         this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell, ucell);
+         this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
      }
      if (ucell.ionic_position_updated)
      {
@@ -404,29 +386,11 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
          auto* dftu = ModuleDFTU::DFTU::get_instance();
          dftu->init(ucell, nullptr, this->kv.get_nks());
      }
-     // after init_rho (in pelec->init_scf), we have rho now.
-     // before hamilt2density, we update Hk and initialize psi
-
-     // before_scf function will be called everytime before scf. However, once
-     // atomic coordinates changed, structure factor will change, therefore all
-     // atomwise properties will change. So we need to reinitialize psi every
-     // time before scf. But for random wavefunction, we dont, because random
-     // wavefunction is not related to atomic coordinates. What the old strategy
-     // does is only to initialize for once...
-     if (((PARAM.inp.init_wfc == "random") && (istep == 0)) || (PARAM.inp.init_wfc != "random"))
-     {
-         this->p_wf_init->initialize_psi(this->psi,
-                                         this->kspw_psi,
-                                         this->p_hamilt,
-                                         this->ppcell,
-                                         ucell,
-                                         GlobalV::ofs_running,
-                                         this->already_initpsi);
-
-         if (this->already_initpsi == false)
-         {
-             this->already_initpsi = true;
-         }
+
+     if (!this->already_initpsi)
+     {
+         this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
+         this->already_initpsi = true;
      }

      ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
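
The last hunk changes when `initialize_psi` is called from `before_scf`. The sketch below compares only the two guard conditions as they appear in the diff. The function names are hypothetical, and it assumes `already_initpsi` is not reset anywhere else, which this hunk does not show. It also ignores that the old `initialize_psi` received `already_initpsi` as an argument and may have behaved differently on repeat calls.

```cpp
#include <cassert>
#include <string>

// Old guard from the removed lines: random wavefunctions are initialized only
// at istep 0, every other init_wfc passes the guard on each call.
inline bool old_should_init(const std::string& init_wfc, const int istep)
{
    return (init_wfc == "random" && istep == 0) || init_wfc != "random";
}

// Hypothetical counter: how many times the old guard passes over nsteps calls.
inline int count_inits_old(const std::string& init_wfc, const int nsteps)
{
    assert(nsteps >= 0);
    int n = 0;
    for (int istep = 0; istep < nsteps; ++istep)
    {
        if (old_should_init(init_wfc, istep)) { ++n; }
    }
    return n;
}

// Hypothetical counter for the new guard: a single flag, so at most one pass,
// assuming nothing else resets already_initpsi.
inline int count_inits_new(const int nsteps)
{
    assert(nsteps >= 0);
    bool already_initpsi = false;
    int n = 0;
    for (int istep = 0; istep < nsteps; ++istep)
    {
        if (!already_initpsi)
        {
            ++n;
            already_initpsi = true;
        }
    }
    return n;
}
```

Under these assumptions, three ionic steps with `init_wfc atomic` pass the old guard three times but the new guard only once, while `init_wfc random` passes once in both cases.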
