diff --git a/source/Makefile.Objects b/source/Makefile.Objects index db5585dd3a..8d4d9933a1 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -580,6 +580,7 @@ OBJS_IO=input_conv.o\ output_mat_sparse.o\ ctrl_output_lcao.o\ ctrl_output_fp.o\ + ctrl_output_pw.o\ para_json.o\ abacusjson.o\ general_info.o\ diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index a073bdab91..7b8dd0c6ae 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -11,16 +11,6 @@ namespace base_device { -// struct CPU; -// struct GPU; - -// enum AbacusDevice_t -// { -// UnKnown, -// CpuDevice, -// GpuDevice -// }; - template base_device::AbacusDevice_t get_device_type(const Device* dev); @@ -73,7 +63,6 @@ int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD); int stringCmp(const void* a, const void* b); #ifdef __CUDA - int set_device_by_rank(const MPI_Comm mpi_comm = MPI_COMM_WORLD); #endif @@ -122,4 +111,4 @@ static __inline__ __device__ double atomicAdd(double* address, double val) } #endif -#endif // MODULE_DEVICE_H_ \ No newline at end of file +#endif // MODULE_DEVICE_H_ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 2c342bfa13..daba4282c5 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -1,6 +1,5 @@ #include "esolver_ks_pw.h" -#include "source_base/formatter.h" #include "source_base/global_variable.h" #include "source_base/kernels/math_kernel_op.h" #include "source_base/memory.h" @@ -14,31 +13,17 @@ #include "source_hsolver/diago_iter_assist.h" #include "source_hsolver/hsolver_pw.h" #include "source_hsolver/kernels/dngvd_op.h" -#include "source_io/berryphase.h" -#include "source_io/cal_ldos.h" -#include "source_io/get_pchg_pw.h" -#include "source_io/get_wf_pw.h" #include "source_io/module_parameter/parameter.h" -#include "source_io/numerical_basis.h" -#include "source_io/numerical_descriptor.h" -#include "source_io/to_wannier90_pw.h" -#include "source_io/write_dos_pw.h" -#include "source_io/write_wfc_pw.h" #include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_pw/module_pwdft/onsite_projector.h" #include "source_lcao/module_dftu/dftu.h" #include "source_pw/module_pwdft/VSep_in_pw.h" -#include "source_pw/module_pwdft/elecond.h" #include "source_pw/module_pwdft/forces.h" #include "source_pw/module_pwdft/hamilt_pw.h" -#include "source_pw/module_pwdft/onsite_projector.h" #include "source_pw/module_pwdft/stress_pw.h" #include -#ifdef __MLALGO -#include "source_io/write_mlkedf_descriptors.h" -#endif - #include #include @@ -48,6 +33,8 @@ #include +#include "source_io/ctrl_output_pw.h" // mohan add 20250927 + namespace ModuleESolver { @@ -62,7 +49,6 @@ ESolver_KS_PW::ESolver_KS_PW() template ESolver_KS_PW::~ESolver_KS_PW() { - // delete Hamilt this->deallocate_hamilt(); @@ -502,6 +488,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste skip_solve = true; } } + if (!skip_solve) { hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, @@ -634,54 +621,6 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } - //---------------------------------------------------------- - // 3) Print out electronic wavefunctions in pw basis - // we only print information every few ionic steps - //---------------------------------------------------------- - - // if istep_in = -1, istep will not appear in file name - // if iter_in = -1, iter will not appear in file name - int istep_in = -1; - int iter_in = -1; - bool out_wfc_flag = false; - if (PARAM.inp.out_freq_ion>0) // default value of out_freq_ion is 0 - { - if (istep % PARAM.inp.out_freq_ion == 0) - { - if(iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) - { - istep_in = istep; - iter_in = iter; - out_wfc_flag = true; - } - } - } - else if(iter == PARAM.inp.scf_nmax || conv_esolver) - { - out_wfc_flag = true; - } - - - if (out_wfc_flag) - { - ModuleIO::write_wfc_pw(istep_in, iter_in, - GlobalV::KPAR, - GlobalV::MY_POOL, - GlobalV::MY_RANK, - PARAM.inp.nbands, - PARAM.inp.nspin, - PARAM.globalv.npol, - GlobalV::RANK_IN_POOL, - GlobalV::NPROC_IN_POOL, - PARAM.inp.out_wfc_pw, - PARAM.inp.ecutwfc, - PARAM.globalv.global_out_dir, - this->psi[0], - this->kv, - this->pw_wfc, - GlobalV::ofs_running); - } - //---------------------------------------------------------- // 4) check if oscillate for delta_spin method //---------------------------------------------------------- @@ -699,6 +638,9 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } } + + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + this->kv, this->pw_wfc, PARAM.inp); } template @@ -733,144 +675,9 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->psi[0].size()); } - //---------------------------------------------------------- - //! 4) Compute density of states (DOS) - //---------------------------------------------------------- - if (PARAM.inp.out_dos) - { - bool out_dos_tmp = false; - - int istep_in = -1; - - // default value of out_freq_ion is 0 - if(PARAM.inp.out_freq_ion==0) - { - out_dos_tmp = true; - } - else if (PARAM.inp.out_freq_ion>0) - { - if (istep % PARAM.inp.out_freq_ion == 0) - { - out_dos_tmp = true; - istep_in=istep; - } - else - { - out_dos_tmp = false; - } - } - else - { - out_dos_tmp = false; - } - - // the above is only valid for KSDFT, not SDFT - // this part needs update in the near future - if (PARAM.inp.esolver_type == "sdft") - { - out_dos_tmp = false; - } - - if(out_dos_tmp) - { - ModuleIO::write_dos_pw(ucell, - this->pelec->ekb, - this->pelec->wg, - this->kv, - PARAM.inp.nbands, - istep_in, - this->pelec->eferm, - PARAM.inp.dos_edelta_ev, - PARAM.inp.dos_scale, - PARAM.inp.dos_sigma, - GlobalV::ofs_running); - } - } - - //------------------------------------------------------------------ - // 5) calculate band-decomposed (partial) charge density in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.out_pchg.size() > 0) - { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); - - ModuleIO::get_pchg_pw(PARAM.inp.out_pchg, - this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, - this->chr.ngmc, - &ucell, - this->__kspw_psi, - this->pw_rhod, - this->pw_wfc, - this->ctx, - this->Pgrid, - PARAM.globalv.global_out_dir, - PARAM.inp.if_separate_k, - this->kv, - GlobalV::KPAR, - GlobalV::MY_POOL, - &this->chr); - } - - //------------------------------------------------------------------ - //! 6) calculate Wannier functions in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90) - { - std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation"); - toWannier90_PW wan(PARAM.inp.out_wannier_mmn, - PARAM.inp.out_wannier_amn, - PARAM.inp.out_wannier_unk, - PARAM.inp.out_wannier_eig, - PARAM.inp.out_wannier_wvfn_formatted, - PARAM.inp.nnkpfile, - PARAM.inp.wannier_spin); - wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, this->pelec->ekb, this->pw_wfc, this->pw_big, this->kv, this->psi); - std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); - } - - //------------------------------------------------------------------ - //! 7) calculate Berry phase polarization in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) - { - std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); - berryphase bp; - bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, this->kv); - std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); - } - - //------------------------------------------------------------------ - // 8) write spin constrian results in pw basis - // spin constrain calculations, write atomic magnetization and magnetic force. - //------------------------------------------------------------------ - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - sc.cal_mi_pw(); - sc.print_Mag_Force(GlobalV::ofs_running); - } - - //------------------------------------------------------------------ - // 9) write onsite occupations for charge and magnetizations - //------------------------------------------------------------------ - if (PARAM.inp.onsite_radius > 0) - { // float type has not been implemented - auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(this->kspw_psi), - this->pelec->wg); - } + ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, + this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -897,18 +704,9 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo : reinterpret_cast, Device>*>(this->kspw_psi); // Calculate forces - ff.cal_force(ucell, - force, - *this->pelec, - this->pw_rhod, - &ucell.symm, - &this->sf, - this->solvent, - &this->locpp, - &this->ppcell, - &this->kv, - this->pw_wfc, - this->__kspw_psi); + ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, + &this->sf, this->solvent, &this->locpp, &this->ppcell, + &this->kv, this->pw_wfc, this->__kspw_psi); } template @@ -925,16 +723,9 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s this->__kspw_psi = PARAM.inp.precision == "single" ? new psi::Psi, Device>(this->kspw_psi[0]) : reinterpret_cast, Device>*>(this->kspw_psi); - ss.cal_stress(stress, - ucell, - this->locpp, - this->ppcell, - this->pw_rhod, - &ucell.symm, - &this->sf, - &this->kv, - this->pw_wfc, - this->__kspw_psi); + + ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod, + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi); // external stress double unit_transform = 0.0; @@ -954,121 +745,11 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) //---------------------------------------------------------- ESolver_KS::after_all_runners(ucell); - //---------------------------------------------------------- - //! 2) Compute LDOS - //---------------------------------------------------------- - if (PARAM.inp.out_ldos[0]) - { - ModuleIO::cal_ldos_pw(reinterpret_cast>*>(this->pelec), - this->psi[0], - this->Pgrid, - ucell); - } - - //---------------------------------------------------------- - //! 3) Calculate the spillage value, - //! which are used to generate numerical atomic orbitals - //---------------------------------------------------------- - if (PARAM.inp.basis_type == "pw" && PARAM.inp.out_spillage) - { - // ! Print out overlap matrices - if (PARAM.inp.out_spillage <= 2) - { - for (int i = 0; i < PARAM.inp.bessel_nao_rcuts.size(); i++) - { - if (GlobalV::MY_RANK == 0) - { - std::cout << "update value: bessel_nao_rcut <- " << std::fixed << PARAM.inp.bessel_nao_rcuts[i] - << " a.u." << std::endl; - } - Numerical_Basis numerical_basis; - numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc, ucell, i); - } - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); - } - } - - //---------------------------------------------------------- - //! 4) Print out electronic wave functions in real space - //---------------------------------------------------------- - if (PARAM.inp.out_wfc_norm.size() > 0 || PARAM.inp.out_wfc_re_im.size() > 0) - { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); - - ModuleIO::get_wf_pw(PARAM.inp.out_wfc_norm, - PARAM.inp.out_wfc_re_im, - this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, - &ucell, - this->__kspw_psi, - this->pw_wfc, - this->ctx, - this->Pgrid, - PARAM.globalv.global_out_dir, - this->kv, - GlobalV::KPAR, - GlobalV::MY_POOL); - } - - //---------------------------------------------------------- - //! 5) Use Kubo-Greenwood method to compute conductivities - //---------------------------------------------------------- - if (PARAM.inp.cal_cond) - { - EleCond elec_cond(&ucell, &this->kv, this->pelec, this->pw_wfc, this->kspw_psi, &this->ppcell); - elec_cond.KG(PARAM.inp.cond_smear, - PARAM.inp.cond_fwhm, - PARAM.inp.cond_wcut, - PARAM.inp.cond_dw, - PARAM.inp.cond_dt, - PARAM.inp.cond_nonlocal, - this->pelec->wg); - } - -#ifdef __MLALGO - //---------------------------------------------------------- - //! 7) generate training data for ML-KEDF - //---------------------------------------------------------- - if (PARAM.inp.of_ml_gene_data == 1) - { - this->pelec->pot->update_from_charge(&this->chr, &ucell); + ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi, + this->kspw_psi, this->__kspw_psi, this->sf, + this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); - ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc; - write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx, - PARAM.inp.nelec, - PARAM.inp.of_tf_weight, - PARAM.inp.of_vw_weight, - PARAM.inp.of_ml_chi_p, - PARAM.inp.of_ml_chi_q, - PARAM.inp.of_ml_chi_xi, - PARAM.inp.of_ml_chi_pnl, - PARAM.inp.of_ml_chi_qnl, - PARAM.inp.of_ml_nkernel, - PARAM.inp.of_ml_kernel, - PARAM.inp.of_ml_kernel_scaling, - PARAM.inp.of_ml_yukawa_alpha, - PARAM.inp.of_ml_kernel_file, - ucell.omega, - this->pw_rho); - - write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - this->kspw_psi, - this->pelec, - this->pw_wfc, - this->pw_rho, - ucell, - this->pelec->pot->get_effective_v(0)); - } -#endif } template class ESolver_KS_PW, base_device::DEVICE_CPU>; diff --git a/source/source_io/CMakeLists.txt b/source/source_io/CMakeLists.txt index ffa7598cd9..e511ffb35c 100644 --- a/source/source_io/CMakeLists.txt +++ b/source/source_io/CMakeLists.txt @@ -1,6 +1,7 @@ list(APPEND objects input_conv.cpp ctrl_output_fp.cpp + ctrl_output_pw.cpp bessel_basis.cpp cal_test.cpp cal_dos.cpp diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp new file mode 100644 index 0000000000..a8c588996a --- /dev/null +++ b/source/source_io/ctrl_output_pw.cpp @@ -0,0 +1,541 @@ +#include "source_io/ctrl_output_pw.h" + +#include "source_io/write_wfc_pw.h" // use write_wfc_pw +#include "source_io/write_dos_pw.h" // use write_dos_pw +#include "source_io/to_wannier90_pw.h" // wannier90 interface +#include "source_pw/module_pwdft/onsite_projector.h" // use projector +#include "source_io/numerical_basis.h" +#include "source_io/numerical_descriptor.h" +#include "source_io/cal_ldos.h" +#include "source_io/berryphase.h" +#include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_base/formatter.h" +#include "source_io/get_pchg_pw.h" +#include "source_io/get_wf_pw.h" +#include "source_pw/module_pwdft/elecond.h" + +#ifdef __MLALGO +#include "source_io/write_mlkedf_descriptors.h" +#endif + +void ModuleIO::ctrl_iter_pw(const int istep, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const Input_para& inp) +{ + ModuleBase::TITLE("ModuleIO", "ctrl_iter_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_iter_pw"); + //---------------------------------------------------------- + // 3) Print out electronic wavefunctions in pw basis + // we only print information every few ionic steps + //---------------------------------------------------------- + + // if istep_in = -1, istep will not appear in file name + // if iter_in = -1, iter will not appear in file name + int istep_in = -1; + int iter_in = -1; + bool out_wfc_flag = false; + if (inp.out_freq_ion>0) // default value of out_freq_ion is 0 + { + if (istep % inp.out_freq_ion == 0) + { + if(iter % inp.out_freq_elec == 0 || iter == inp.scf_nmax || conv_esolver) + { + istep_in = istep; + iter_in = iter; + out_wfc_flag = true; + } + } + } + else if(iter == inp.scf_nmax || conv_esolver) + { + out_wfc_flag = true; + } + + if (out_wfc_flag) + { + ModuleIO::write_wfc_pw(istep_in, iter_in, + GlobalV::KPAR, + GlobalV::MY_POOL, + GlobalV::MY_RANK, + inp.nbands, + inp.nspin, + PARAM.globalv.npol, + GlobalV::RANK_IN_POOL, + GlobalV::NPROC_IN_POOL, + inp.out_wfc_pw, + inp.ecutwfc, + PARAM.globalv.global_out_dir, + psi[0], + kv, + pw_wfc, + GlobalV::ofs_running); + } + + ModuleBase::timer::tick("ModuleIO", "ctrl_iter_pw"); + return; +} + + +template +void ModuleIO::ctrl_scf_pw(const int istep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + const Device* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp) +{ + ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + + //---------------------------------------------------------- + //! 4) Compute density of states (DOS) + //---------------------------------------------------------- + if (inp.out_dos) + { + bool out_dos_tmp = false; + + int istep_in = -1; + + // default value of out_freq_ion is 0 + if(inp.out_freq_ion==0) + { + out_dos_tmp = true; + } + else if (inp.out_freq_ion>0) + { + if (istep % inp.out_freq_ion == 0) + { + out_dos_tmp = true; + istep_in=istep; + } + else + { + out_dos_tmp = false; + } + } + else + { + out_dos_tmp = false; + } + + // the above is only valid for KSDFT, not SDFT + // Needs update in the near future + if (inp.esolver_type == "sdft") + { + out_dos_tmp = false; + } + + if(out_dos_tmp) + { + ModuleIO::write_dos_pw(ucell, + pelec->ekb, + pelec->wg, + kv, + inp.nbands, + istep_in, + pelec->eferm, + inp.dos_edelta_ev, + inp.dos_scale, + inp.dos_sigma, + GlobalV::ofs_running); + } + } + + + //------------------------------------------------------------------ + // 5) calculate band-decomposed (partial) charge density in pw basis + //------------------------------------------------------------------ + if (inp.out_pchg.size() > 0) + { + if (__kspw_psi != nullptr && inp.precision == "single") + { + delete reinterpret_cast, Device>*>(__kspw_psi); + } + + // Refresh __kspw_psi + __kspw_psi = inp.precision == "single" + ? new psi::Psi, Device>(kspw_psi[0]) + : reinterpret_cast, Device>*>(kspw_psi); + + const int nbands = kspw_psi->get_nbands(); + const int ngmc = chr.ngmc; + + ModuleIO::get_pchg_pw(inp.out_pchg, + nbands, + inp.nspin, + pw_rhod->nxyz, + ngmc, + &ucell, + __kspw_psi, + pw_rhod, + pw_wfc, + ctx, + para_grid, + PARAM.globalv.global_out_dir, + inp.if_separate_k, + kv, + GlobalV::KPAR, + GlobalV::MY_POOL, + &chr); + } + + + //------------------------------------------------------------------ + //! 6) calculate Wannier functions in pw basis + //------------------------------------------------------------------ + if (inp.calculation == "nscf" && inp.towannier90) + { + std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation"); + toWannier90_PW wan(inp.out_wannier_mmn, + inp.out_wannier_amn, + inp.out_wannier_unk, + inp.out_wannier_eig, + inp.out_wannier_wvfn_formatted, + inp.nnkpfile, + inp.wannier_spin); + wan.set_tpiba_omega(ucell.tpiba, ucell.omega); + wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi); + std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); + } + + + //------------------------------------------------------------------ + //! 7) calculate Berry phase polarization in pw basis + //------------------------------------------------------------------ + if (inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) + { + std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); + berryphase bp; + bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi, pw_rho, pw_wfc, kv); + std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); + } + + //------------------------------------------------------------------ + // 8) write spin constrian results in pw basis + // spin constrain calculations, write atomic magnetization and magnetic force. + //------------------------------------------------------------------ + if (inp.sc_mag_switch) + { + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + sc.cal_mi_pw(); + sc.print_Mag_Force(GlobalV::ofs_running); + } + + //------------------------------------------------------------------ + // 9) write onsite occupations for charge and magnetizations + //------------------------------------------------------------------ + if (inp.onsite_radius > 0) + { // float type has not been implemented + auto* onsite_p = projectors::OnsiteProjector::get_instance(); + onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), + pelec->wg); + } + + ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + return; +} + +template +void ModuleIO::ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp) +{ + ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); + + //---------------------------------------------------------- + //! 1) Compute LDOS + //---------------------------------------------------------- + if (inp.out_ldos[0]) + { + ModuleIO::cal_ldos_pw(reinterpret_cast>*>(pelec), + psi[0], para_grid, ucell); + } + + //---------------------------------------------------------- + //! 2) Calculate the spillage value, + //! which are used to generate numerical atomic orbitals + //---------------------------------------------------------- + if (inp.basis_type == "pw" && inp.out_spillage) + { + // ! Print out overlap matrices + if (inp.out_spillage <= 2) + { + for (int i = 0; i < inp.bessel_nao_rcuts.size(); i++) + { + if (GlobalV::MY_RANK == 0) + { + std::cout << "update value: bessel_nao_rcut <- " << std::fixed << inp.bessel_nao_rcuts[i] + << " a.u." << std::endl; + } + Numerical_Basis numerical_basis; + numerical_basis.output_overlap(psi[0], sf, kv, pw_wfc, ucell, i); + } + ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); + } + } + + //---------------------------------------------------------- + //! 3) Print out electronic wave functions in real space + //---------------------------------------------------------- + if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) + { + if (__kspw_psi != nullptr && inp.precision == "single") + { + delete reinterpret_cast, Device>*>(__kspw_psi); + } + + // Refresh __kspw_psi + __kspw_psi = inp.precision == "single" + ? new psi::Psi, Device>(kspw_psi[0]) + : reinterpret_cast, Device>*>(kspw_psi); + + ModuleIO::get_wf_pw(inp.out_wfc_norm, + inp.out_wfc_re_im, + kspw_psi->get_nbands(), + inp.nspin, + pw_rhod->nxyz, + &ucell, + __kspw_psi, + pw_wfc, + ctx, + para_grid, + PARAM.globalv.global_out_dir, + kv, + GlobalV::KPAR, + GlobalV::MY_POOL); + } + + //---------------------------------------------------------- + //! 4) Use Kubo-Greenwood method to compute conductivities + //---------------------------------------------------------- + if (inp.cal_cond) + { + using Real = typename GetTypeReal::type; + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); + elec_cond.KG(inp.cond_smear, + inp.cond_fwhm, + inp.cond_wcut, + inp.cond_dw, + inp.cond_dt, + inp.cond_nonlocal, + pelec->wg); + } + +#ifdef __MLALGO + //---------------------------------------------------------- + //! 7) generate training data for ML-KEDF + //---------------------------------------------------------- + if (inp.of_ml_gene_data == 1) + { + pelec->pot->update_from_charge(&chr, &ucell); + + ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc; + write_mlkedf_desc.cal_tool->set_para(chr.nrxx, + inp.nelec, + inp.of_tf_weight, + inp.of_vw_weight, + inp.of_ml_chi_p, + inp.of_ml_chi_q, + inp.of_ml_chi_xi, + inp.of_ml_chi_pnl, + inp.of_ml_chi_qnl, + inp.of_ml_nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, + ucell.omega, + pw_rho); + + write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, + kspw_psi, + pelec, + pw_wfc, + pw_rho, + ucell, + pelec->pot->get_effective_v(0)); + } +#endif + + ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); +} + +// complex + CPU +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + const int nstep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +// complex + CPU +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + const int nstep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +#if ((defined __CUDA) || (defined __ROCM)) +// complex + GPU +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + const int nstep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + const base_device::DEVICE_GPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +// complex + GPU +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + const int nstep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + const base_device::DEVICE_GPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); +#endif + +// complex + CPU +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DEVICE_CPU* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + +// complex + CPU +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DEVICE_CPU* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + +#if ((defined __CUDA) || (defined __ROCM)) +// complex + GPU +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DEVICE_GPU* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + +// complex + GPU +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DEVICE_GPU* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); +#endif diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h new file mode 100644 index 0000000000..87fea245b0 --- /dev/null +++ b/source/source_io/ctrl_output_pw.h @@ -0,0 +1,58 @@ +#ifndef CTRL_OUTPUT_PW_H +#define CTRL_OUTPUT_PW_H + +#include "source_base/module_device/device.h" // use Device +#include "source_psi/psi.h" // define psi +#include "source_estate/elecstate_lcao.h" // use pelec + +namespace ModuleIO +{ + +// print out information in 'iter_finish' in ESolver_KS_PW +void ctrl_iter_pw(const int istep, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const Input_para& inp); + +// print out information in 'after_scf' in ESolver_KS_PW +template +void ctrl_scf_pw(const int istep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + const Device* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +// print out information in 'after_all_runners' in ESolver_KS_PW +template +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + +} +#endif diff --git a/source/source_io/get_pchg_pw.h b/source/source_io/get_pchg_pw.h index 7c2d14fe6a..2a61c77aa3 100644 --- a/source/source_io/get_pchg_pw.h +++ b/source/source_io/get_pchg_pw.h @@ -1,7 +1,8 @@ #ifndef GET_PCHG_PW_H #define GET_PCHG_PW_H -#include "cube_io.h" +#include "source_io/cube_io.h" +#include "source_estate/module_charge/symmetry_rho.h" namespace ModuleIO {