From 1e288f9036b98bb6db499360451ba2726f504355 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Wed, 15 Jan 2025 21:33:19 -0800 Subject: [PATCH 01/14] change npol to private --- source/module_elecstate/elecstate_pw.cpp | 2 +- source/module_hamilt_general/operator.cpp | 6 +++--- source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp | 2 +- .../module_deltaspin/cal_mw_from_lambda.cpp | 4 ++-- source/module_hamilt_lcao/module_dftu/dftu_pw.cpp | 4 ++-- source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 4 ++-- source/module_io/write_vxc_lip.hpp | 2 +- source/module_psi/psi.cpp | 4 ++-- source/module_psi/psi.h | 5 ++++- 10 files changed, 20 insertions(+), 17 deletions(-) diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index f55f2ec447..0b4fbe0368 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -271,7 +271,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) { const T one{1, 0}; const T zero{0, 0}; - const int npol = psi.npol; + const int npol = psi.get_npol(); const int npwx = psi.get_nbasis() / npol; const int nbands = psi.get_nbands() * npol; const int nkb = this->ppcell->nkb; diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 008d5e30e3..3f9e43a99c 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -63,7 +63,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp delete this->hpsi; this->hpsi = new psi::Psi(hpsi_pointer, 1, - nbands / psi_input->npol, + nbands / psi_input->get_npol(), psi_input->get_nbasis(), psi_input->get_nbasis(), true); @@ -86,7 +86,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp default: op->act(nbands, psi_input->get_nbasis(), - psi_input->npol, + psi_input->get_npol(), tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), @@ -105,7 +105,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } ModuleBase::timer::tick("Operator", "hPsi"); - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer); } template diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp index 94c5c74db7..7836368709 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp @@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_CPU>* psi_t = static_cast, base_device::DEVICE_CPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 87a2fa41cc..79f8ba1fff 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); diff --git a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp index cc0c3a6c30..d8bc62f3a5 100644 --- a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp +++ b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp @@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index 2bb69dc131..32a4902221 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -165,7 +165,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_); RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_); //rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq); - rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol); + rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol); // For being compatible with present cal_force and cal_stress framework // uncomment the following code block if you want to use the Onsite_Proj_tools if(this->tab_atomic_ == nullptr) @@ -541,7 +541,7 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psi::before_all_runners(UnitCell& ucell, const Input_p this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max); + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(PARAM.inp.pw_seed); this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single" diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 667b440916..f5f9292522 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -78,13 +78,20 @@ void ESolver_SDFT_PW::before_all_runners(UnitCell& ucell, const Input // 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}> size_t size = stowf.chi0->size(); this->stowf.shchi - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + this->pw_wfc->npwk_max, + this->kv.ngk, + true); ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T)); if (PARAM.inp.nbands > 0) { this->stowf.chiortho - = new psi::Psi(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data()); + = new psi::Psi(this->kv.get_nks(), + this->stowf.nchip_max, + this->pw_wfc->npwk_max, + this->kv.ngk, true); ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T)); } diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index 2637fe41d8..0bf61d6947 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -159,7 +159,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_esolver/lcao_others.cpp b/source/module_esolver/lcao_others.cpp index fc0ab246d3..faca3563f0 100644 --- a/source/module_esolver/lcao_others.cpp +++ b/source/module_esolver/lcao_others.cpp @@ -165,7 +165,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) ncol = PARAM.inp.nbands; #endif } - this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->pv.nrow, this->kv.ngk, true); } // init wfc from file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 8a76daa9e9..49e0ab7469 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -35,7 +35,7 @@ template void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) { this->nks = p_kv->get_nks(); - this->ngk = p_kv->ngk.data(); + this->ngk = p_kv->ngk; this->npwx = npwx_in; nchip = new int[nks]; @@ -111,7 +111,7 @@ void Stochastic_WF::allocate_chi0() this->nchip_max = tmpnchip; size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0_cpu = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(T)); for (int ik = 0; ik < nks; ++ik) @@ -123,7 +123,7 @@ void Stochastic_WF::allocate_chi0() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -207,7 +207,7 @@ void Stochastic_WF::init_com_orbitals() delete[] npwip; } size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, this->nchip_max, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -252,7 +252,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { @@ -266,7 +266,7 @@ void Stochastic_WF::init_com_orbitals() const int npwx = this->npwx; const int nks = this->nks; size_t size = this->nchip_max * npwx * nks; - this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk); + this->chi0_cpu = new psi::Psi>(nks, npwx, npwx, this->ngk, true); this->chi0_cpu->zero_out(); ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex)); for (int ik = 0; ik < nks; ++ik) @@ -284,7 +284,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, ture); } else { diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index a423810544..4afdeb4247 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -30,10 +30,10 @@ class Stochastic_WF int* nchip = nullptr; ///< The number of stochatic orbitals in current process of each k point. int nchip_max = 0; ///< Max number of stochastic orbitals among all k points. int nks = 0; ///< number of k-points - int* ngk = nullptr; ///< ngk in klist int npwx = 0; ///< max ngk[ik] in all processors int nbands_diag = 0; ///< number of bands obtained from diagonalization int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag; + std::vector ngk; ///< ngk in klist public: // Tn(H)|chi> psi::Psi* chiallorder = nullptr; diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 2f9a1b4313..44deac1bbd 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -219,7 +219,7 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, k2d.distribute_hsk(pHamilt, ik_kpar, nrow); /// global index of k point int ik_global = ik + k2d.get_pKpoints()->startk_pool[k2d.get_my_pool()]; - auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, nullptr); + auto psi_pool = psi::Psi(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true); ModuleBase::Memory::record("HSolverLCAO::psi_pool", nrow * ncol_bands_pool * sizeof(T)); if (ik_global < psi.get_nk() && ik < k2d.get_pKpoints()->nks_pool[k2d.get_my_pool()]) { diff --git a/source/module_hsolver/test/diago_mock.h b/source/module_hsolver/test/diago_mock.h index e63022f43d..85a7750fc5 100644 --- a/source/module_hsolver/test/diago_mock.h +++ b/source/module_hsolver/test/diago_mock.h @@ -214,7 +214,7 @@ class HPsi { Structure_Factor* sf; int* ngk = nullptr; - psi::Psi psitmp(1, nband, npw, ngk); + psi::Psi psitmp(1, nband, npw, npw, true); for(int i=0;i>(nks, nbands, wfcpw->npwk_max, wfcpw->npwk); + psi = new psi::Psi>(nks, nbands, wfcpw->npwk_max, kv->ngk, true); std::complex* ptr = psi->get_pointer(); for (int i = 0; i < nks * nbands * wfcpw->npwk_max; i++) { diff --git a/source/module_io/to_wannier90_lcao_in_pw.cpp b/source/module_io/to_wannier90_lcao_in_pw.cpp index 78f04a8a97..e067671465 100644 --- a/source/module_io/to_wannier90_lcao_in_pw.cpp +++ b/source/module_io/to_wannier90_lcao_in_pw.cpp @@ -52,7 +52,11 @@ void toWannier90_LCAO_IN_PW::calculate( const int nks_psi = (PARAM.inp.calculation == "nscf" && PARAM.inp.mem_saver == 1)? 1 : wfcpw->nks; const int nks_psig = (PARAM.inp.basis_type == "pw")? 1 : nks_psi; const int nbands_actual = this->psi_initer_->nbands_start(); - this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, nbands_actual, wfcpw->npwk_max*PARAM.globalv.npol, wfcpw->npwk); + this->psi = new psi::Psi, base_device::DEVICE_CPU>(nks_psig, + nbands_actual, + wfcpw->npwk_max*PARAM.globalv.npol, + kv.ngk, + true); read_nnkp(ucell,kv); if (PARAM.inp.nspin == 2) @@ -117,7 +121,11 @@ psi::Psi>* toWannier90_LCAO_IN_PW::get_unk_from_lcao( { // init int npwx = wfcpw->npwk_max; - psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, num_bands, npwx*PARAM.globalv.npol, kv.ngk.data()); + psi::Psi> *unk_inLcao = new psi::Psi>(num_kpts, + num_bands, + npwx*PARAM.globalv.npol, + kv.ngk, + true); unk_inLcao->zero_out(); // Orbital projection to plane wave diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 187e14fb89..dffc242205 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -85,7 +85,7 @@ Psi::Psi(const int nk_in, const bool k_first_in) { assert(nk_in > 0); - assert(nbd_in > 0); + assert(nbd_in >= 0); assert(nbs_in > 0); this->k_first = k_first_in; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index d9b075548d..6c96392c88 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -40,7 +40,7 @@ class Psi Psi(); // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in = true); + Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in); // Constructor 1-2: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index 2cdce4a5a8..28a1fb4c1b 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -106,7 +106,7 @@ void PSIInit::initialize_psi(Psi>* psi, if (not_equal) { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) : reinterpret_cast*>(psi_cpu); } @@ -119,7 +119,7 @@ void PSIInit::initialize_psi(Psi>* psi, } else { - psi_cpu = new Psi(1, nbands_start, nbasis, nullptr); + psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = kspw_psi; } } @@ -203,7 +203,7 @@ void PSIInit::initialize_lcao_in_pw(Psi* psi_local, std::ofstream& } } -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx) +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx) { assert(npwx > 0); assert(nks > 0); @@ -215,7 +215,7 @@ void allocate_psi(Psi>*& psi, const int& nks, const int* ng { nks2 = 1; } - psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk); + psi = new psi::Psi>(nks2, nbands, npwx * PARAM.globalv.npol, ngk, true); const size_t memory_cost = sizeof(std::complex) * nks2 * nbands * (PARAM.globalv.npol * npwx); std::cout << " MEMORY FOR PSI (MB) : " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("Psi_PW", memory_cost); diff --git a/source/module_psi/psi_init.h b/source/module_psi/psi_init.h index e112a71a6e..bf93e534d0 100644 --- a/source/module_psi/psi_init.h +++ b/source/module_psi/psi_init.h @@ -86,7 +86,7 @@ class PSIInit }; ///@brief allocate the wavefunction -void allocate_psi(Psi>*& psi, const int& nks, const int* ngk, const int& nbands, const int& npwx); +void allocate_psi(Psi>*& psi, const int& nks, const std::vector& ngk, const int& nbands, const int& npwx); } // namespace psi #endif \ No newline at end of file diff --git a/source/module_psi/psi_initializer_atomic_random.cpp b/source/module_psi/psi_initializer_atomic_random.cpp index f7b735f5ed..7e0652c25c 100644 --- a/source/module_psi/psi_initializer_atomic_random.cpp +++ b/source/module_psi/psi_initializer_atomic_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_atomic_random::init_psig(T* psig, const int& ik) psi_initializer_atomic::init_psig(psig, ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/psi_initializer_nao_random.cpp b/source/module_psi/psi_initializer_nao_random.cpp index 4f8b8d940f..ab23c4a163 100644 --- a/source/module_psi/psi_initializer_nao_random.cpp +++ b/source/module_psi/psi_initializer_nao_random.cpp @@ -21,7 +21,7 @@ void psi_initializer_nao_random::init_psig(T* psig, const int& ik) psi_initializer_nao::init_psig(psig, ik); const int npol = PARAM.globalv.npol; const int nbasis = this->pw_wfc_->npwk_max * npol; - psi::Psi psi_random(1, this->nbands_start_, nbasis, nullptr); + psi::Psi psi_random(1, this->nbands_start_, nbasis, nbasis, true); psi_random.fix_k(0); this->random_t(psi_random.get_pointer(), 0, this->nbands_start_, ik, 0); for (int iband = 0; iband < this->nbands_start_; iband++) diff --git a/source/module_psi/test/psi_initializer_unit_test.cpp b/source/module_psi/test/psi_initializer_unit_test.cpp index fd9dcd497c..b5b4180b2d 100644 --- a/source/module_psi/test/psi_initializer_unit_test.cpp +++ b/source/module_psi/test/psi_initializer_unit_test.cpp @@ -321,7 +321,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(-0.66187696761064307, psi->operator()(0,0,0).real(), 1e-4); delete psi; @@ -340,7 +340,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomic) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -363,7 +363,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -390,7 +390,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); PARAM.input.nspin = 1; @@ -413,7 +413,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigAtomicRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -432,7 +432,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNao) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -451,7 +451,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoRandom) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -475,7 +475,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSoc) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -499,7 +499,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSo) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; @@ -523,7 +523,7 @@ TEST_F(PsiIntializerUnitTest, CalPsigNaoSocHasSoDOMAG) { this->psi_init->tabulate(); // always: new, initialize, tabulate, allocate, proj_ao_onkG const int nbands_start = this->psi_init->nbands_start(); const int nbasis = this->p_pw_wfc->npwk_max * PARAM.globalv.npol; - psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nullptr); + psi::Psi>* psi = new psi::Psi>(1, nbands_start, nbasis, nbasis, true); this->psi_init->init_psig(psi->get_pointer(), 0); EXPECT_NEAR(0, psi->operator()(0,0,0).real(), 1e-12); delete psi; diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index 0b42df63c7..af6af855b1 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -8,12 +8,12 @@ class TestPsi : public ::testing::Test const int ink = 2; const int inbands = 4; const int inbasis = 10; - int ngk[4] = {10, 10, 10, 10}; + std::vector ngk = {10, 10}; - const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); - const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, &ngk[0]); - const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, &ngk[0]); + const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, ngk, true); + const psi::Psi>* psi_object33 = new psi::Psi>(ink, inbands, inbasis, ngk, true); + const psi::Psi* psi_object34 = new psi::Psi(ink, inbands, inbasis, ngk, true); }; TEST_F(TestPsi, get_val) From 99f10f4109ed3537a28dbd1b146dd2ca2498b9f7 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 16 Jan 2025 04:30:49 -0800 Subject: [PATCH 09/14] fix bug --- source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 49e0ab7469..e3e3cb7cb9 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -284,7 +284,7 @@ void Stochastic_WF::init_com_orbitals() Device* ctx = {}; if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, ture); + this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } else { From f6a220f8eb778e477d8cca81940a102d54dc682c Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 16 Jan 2025 17:57:27 -0800 Subject: [PATCH 10/14] fix test bug --- source/module_psi/test/psi_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index af6af855b1..b57e1b4caf 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -8,7 +8,7 @@ class TestPsi : public ::testing::Test const int ink = 2; const int inbands = 4; const int inbasis = 10; - std::vector ngk = {10, 10}; + std::vector ngk = {10, 10, 10, 10}; const psi::Psi>* psi_object31 = new psi::Psi>(ink, inbands, inbasis, ngk, true); const psi::Psi* psi_object32 = new psi::Psi(ink, inbands, inbasis, ngk, true); From 05e486afeccb80ca51467cec26160e946537dc9e Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 16 Jan 2025 23:22:04 -0800 Subject: [PATCH 11/14] remove Constructor 1-1 --- .../test/ao_to_mo_test.cpp | 72 +++++++++---------- .../module_lr/dm_trans/test/dm_trans_test.cpp | 48 ++++++------- source/module_psi/psi.cpp | 62 ++++++++-------- source/module_psi/psi.h | 4 +- source/module_psi/test/psi_test.cpp | 10 +-- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp index 5601ad451d..8bcb88b525 100644 --- a/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp +++ b/source/module_lr/ao_to_mo_transformer/test/ao_to_mo_test.cpp @@ -64,18 +64,18 @@ TEST_F(AO2MOTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data(), size_v); } @@ -96,18 +96,18 @@ TEST_F(AO2MOTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); - psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_for(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> vo_blas(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_for(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> oo_blas(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_for(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); + psi::Psi> vv_blas(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; int size_v = s.naos * s.naos; for (int istate = 0;istate < nstate;++istate) { std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); std::vector V(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); set_rand(&c(0, 0, 0), size_c); for (auto& v : V) { set_rand(v.data>(), size_v); } @@ -137,7 +137,7 @@ TEST_F(AO2MOTest, DoubleParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); @@ -148,12 +148,12 @@ TEST_F(AO2MOTest, DoubleParallel) EXPECT_GE(s.nvirt, pvo.dim0); EXPECT_GE(s.nocc, pvo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi oo_pblas_loc(s.nks, nstate, poo.get_local_size(), poo.get_local_size(), false); + psi::Psi oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -174,7 +174,7 @@ TEST_F(AO2MOTest, DoubleParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_1(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_1, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data(), V_full.at(isk).data(), false, s.naos, s.naos); @@ -182,13 +182,13 @@ TEST_F(AO2MOTest, DoubleParallel) } if (my_rank == 0) { - psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 0), s.nks * s.nocc * s.nvirt); - psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nocc, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } @@ -208,18 +208,18 @@ TEST_F(AO2MOTest, ComplexParallel) LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, pV.blacs_ctxt); std::vector ngk_temp_1(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), ngk_temp_1, true); Parallel_2D pvo, poo, pvv; LR_Util::setup_2d_division(pvo, s.nb, s.nvirt, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(poo, s.nb, s.nocc, s.nocc, pV.blacs_ctxt); LR_Util::setup_2d_division(pvv, s.nb, s.nvirt, s.nvirt, pV.blacs_ctxt); - psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), nullptr, false); - psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); - psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), nullptr, false); - psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, nullptr, false); - psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), nullptr, false); - psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vo_pblas_loc(s.nks, nstate, pvo.get_local_size(), pvo.get_local_size(), false); + psi::Psi> vo_gather(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); + psi::Psi> oo_pblas_loc(s.nks, nstate, poo.get_local_size(), poo.get_local_size(), false); + psi::Psi> oo_gather(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); + psi::Psi> vv_pblas_loc(s.nks, nstate, pvv.get_local_size(), pvv.get_local_size(), false); + psi::Psi> vv_gather(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); for (int istate = 0;istate < nstate;++istate) { for (int isk = 0;isk < s.nks;++isk) @@ -241,7 +241,7 @@ TEST_F(AO2MOTest, ComplexParallel) // compare to global AX std::vector V_full(s.nks, container::Tensor(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { s.naos, s.naos })); std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { LR_Util::gather_2d_to_full(pV, V.at(isk).data>(), V_full.at(isk).data>(), false, s.naos, s.naos); @@ -249,13 +249,13 @@ TEST_F(AO2MOTest, ComplexParallel) } if (my_rank == 0) { - psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, nullptr, false); + psi::Psi> vo_full_istate(s.nks, 1, s.nocc * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vo_full_istate(0, 0, 0), false); check_eq(&vo_full_istate(0, 0, 0), &vo_gather(istate, 0, 0), s.nks * s.nocc * s.nvirt); - psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, nullptr, false); + psi::Psi> oo_full_istate(s.nks, 1, s.nocc * s.nocc, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nocc, &oo_full_istate(0, 0, 0), false, LR::MO_TYPE::OO); check_eq(&oo_full_istate(0, 0, 0), &oo_gather(istate, 0, 0), s.nks * s.nocc * s.nocc); - psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> vv_full_istate(s.nks, 1, s.nvirt * s.nvirt, s.nocc * s.nvirt, false); LR::ao_to_mo_blas(V_full, c_full, s.nocc, s.nvirt, &vv_full_istate(0, 0, 0), false, LR::MO_TYPE::VV); check_eq(&vv_full_istate(0, 0, 0), &vv_gather(istate, 0, 0), s.nks * s.nvirt * s.nvirt); } diff --git a/source/module_lr/dm_trans/test/dm_trans_test.cpp b/source/module_lr/dm_trans/test/dm_trans_test.cpp index 8a40f08c61..acef1e8a40 100644 --- a/source/module_lr/dm_trans/test/dm_trans_test.cpp +++ b/source/module_lr/dm_trans/test/dm_trans_test.cpp @@ -61,18 +61,18 @@ TEST_F(DMTransTest, DoubleSerial) { for (auto s : this->sizes) { - psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi& X, const LR::MO_TYPE type) { @@ -92,18 +92,18 @@ TEST_F(DMTransTest, ComplexSerial) { for (auto s : this->sizes) { - psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); + psi::Psi> X_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); set_rand(X_vo.get_pointer(), nstate * s.nks * s.nocc * s.nvirt); - psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); + psi::Psi> X_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); set_rand(X_oo.get_pointer(), nstate * s.nks * s.nocc * s.nocc); - psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); + psi::Psi> X_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); set_rand(X_vv.get_pointer(), nstate * s.nks * s.nvirt * s.nvirt); for (int istate = 0;istate < nstate;++istate) { const int size_c = s.nks * (s.nocc + s.nvirt) * s.naos; std::vector temp(s.nks, s.naos); - psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi> c(s.nks, s.nocc + s.nvirt, s.naos, temp, true); set_rand(c.get_pointer(), size_c); auto test = [&](psi::Psi>& X, const LR::MO_TYPE type) { @@ -132,18 +132,18 @@ TEST_F(DMTransTest, DoubleParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp_2(s.nks, pc.get_row_size()); - psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2.data(), true); + psi::Psi c(s.nks, pc.get_col_size(), pc.get_row_size(), temp_2, true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); @@ -153,9 +153,9 @@ TEST_F(DMTransTest, DoubleParallel) EXPECT_GE(s.nocc, px_vo.dim1); EXPECT_GE(s.naos, pc.dim0); - psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nocc, false); // allocate X_full + psi::Psi X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nvirt * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi& X, psi::Psi& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -182,7 +182,7 @@ TEST_F(DMTransTest, DoubleParallel) // gather C std::vector temp(s.nks, s.naos); - psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp.data(), true); + psi::Psi c_full(s.nks, s.nocc + s.nvirt, s.naos, temp, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); @@ -223,24 +223,24 @@ TEST_F(DMTransTest, ComplexParallel) LR_Util::setup_2d_division(px_oo, s.nb, s.nocc, s.nocc, px_vo.blacs_ctxt); LR_Util::setup_2d_division(px_vv, s.nb, s.nvirt, s.nvirt, px_vo.blacs_ctxt); - psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), nullptr, false); + psi::Psi> X_vo(s.nks, nstate, px_vo.get_local_size(), px_vo.get_local_size(), false); set_rand(X_vo.get_pointer(), nstate * s.nks * px_vo.get_local_size()); - psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), nullptr, false); + psi::Psi> X_oo(s.nks, nstate, px_oo.get_local_size(), px_oo.get_local_size(), false); set_rand(X_oo.get_pointer(), nstate * s.nks * px_oo.get_local_size()); - psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), nullptr, false); + psi::Psi> X_vv(s.nks, nstate, px_vv.get_local_size(), px_vv.get_local_size(), false); set_rand(X_vv.get_pointer(), nstate * s.nks * px_vv.get_local_size()); Parallel_2D pc; LR_Util::setup_2d_division(pc, s.nb, s.naos, s.nocc + s.nvirt, px_vo.blacs_ctxt); std::vector temp(s.nks, pc.get_row_size()); - psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp.data(), true); + psi::Psi> c(s.nks, pc.get_col_size(), pc.get_row_size(), temp, true); Parallel_2D pmat; LR_Util::setup_2d_division(pmat, s.nb, s.naos, s.naos, px_vo.blacs_ctxt); - psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, nullptr, false); // allocate X_full - psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, nullptr, false); // allocate X_full - psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, nullptr, false); // allocate X_full + psi::Psi> X_full_vo(s.nks, nstate, s.nocc * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi> X_full_oo(s.nks, nstate, s.nocc * s.nocc, s.nocc * s.nvirt, false); // allocate X_full + psi::Psi> X_full_vv(s.nks, nstate, s.nvirt * s.nvirt, s.nocc * s.nvirt, false); // allocate X_full auto gather = [&](const psi::Psi>& X, psi::Psi>& X_full, const Parallel_2D& px, const int dim1, const int dim2) { @@ -266,7 +266,7 @@ TEST_F(DMTransTest, ComplexParallel) set_rand(c.get_pointer(), s.nks * pc.get_local_size()); // set c // compare to global matrix std::vector ngk_temp_2(s.nks, s.naos); - psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2.data(), true); + psi::Psi> c_full(s.nks, s.nocc + s.nvirt, s.naos, ngk_temp_2, true); for (int isk = 0;isk < s.nks;++isk) { c.fix_k(isk); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index dffc242205..1f0c593d1c 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -43,38 +43,38 @@ Psi::~Psi() } } -// Constructor 1-1: -template -Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) -{ - assert(nk_in > 0); - assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU - assert(nbs_in > 0); - - this->k_first = k_first_in; - this->allocate_inside = true; - - this->ngk = ngk_in; // modify later - // This function will delete the psi array first(if psi exist), then malloc a new memory for it. - resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); - - this->nk = nk_in; - this->nbands = nbd_in; - this->nbasis = nbs_in; +// // Constructor 1-1: +// template +// Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) +// { +// assert(nk_in > 0); +// assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU +// assert(nbs_in > 0); + +// this->k_first = k_first_in; +// this->allocate_inside = true; + +// this->ngk = ngk_in; // modify later +// // This function will delete the psi array first(if psi exist), then malloc a new memory for it. +// resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); + +// this->nk = nk_in; +// this->nbands = nbd_in; +// this->nbasis = nbs_in; - this->current_b = 0; - this->current_k = 0; - this->current_nbasis = nbs_in; - this->psi_current = this->psi; - this->psi_bias = 0; - - // Currently only GPU's implementation is supported for device recording! - base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); - base_device::information::record_device_memory(this->ctx, - GlobalV::ofs_device, - "Psi->resize()", - sizeof(T) * nk_in * nbd_in * nbs_in); -} +// this->current_b = 0; +// this->current_k = 0; +// this->current_nbasis = nbs_in; +// this->psi_current = this->psi; +// this->psi_bias = 0; + +// // Currently only GPU's implementation is supported for device recording! +// base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); +// base_device::information::record_device_memory(this->ctx, +// GlobalV::ofs_device, +// "Psi->resize()", +// sizeof(T) * nk_in * nbd_in * nbs_in); +// } // Constructor 1-2: template diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 6c96392c88..7d4aaa7d61 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -39,8 +39,8 @@ class Psi // Constructor 0: basic Psi(); - // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in); + // // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later + // Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in); // Constructor 1-2: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index b57e1b4caf..598cbe21bd 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -98,7 +98,7 @@ TEST_F(TestPsi, get_pointer_op_zero_complex_double) EXPECT_EQ(psi_object31->get_psi_bias(), 0); std::vector temp(ink, inbasis); - psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp.data(), true); + psi::Psi>* psi_temp = new psi::Psi>(ink, inbands, inbasis, temp, true); psi_temp->fix_k(0); EXPECT_EQ(psi_object31->get_current_nbas(), inbasis); delete psi_temp; @@ -241,10 +241,10 @@ TEST_F(TestPsi, range) TEST_F(TestPsi, band_first) { - const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, &ngk[0], false); - const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, &ngk[0], false); + const psi::Psi>* psi_band_c64 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_64 = new psi::Psi(ink, inbands, inbasis, ngk, false); + const psi::Psi>* psi_band_c32 = new psi::Psi>(ink, inbands, inbasis, ngk, false); + const psi::Psi* psi_band_32 = new psi::Psi(ink, inbands, inbasis, ngk, false); // set values: cover 4 different cases for (int ib = 0;ib < inbands;++ib) From c562b604574c674a74997d367af3ff92c7dd65bd Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 17 Jan 2025 00:25:23 -0800 Subject: [PATCH 12/14] fix bug --- source/module_ri/exx_lip.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_ri/exx_lip.hpp b/source/module_ri/exx_lip.hpp index 6be31a26b4..5e26446df4 100644 --- a/source/module_ri/exx_lip.hpp +++ b/source/module_ri/exx_lip.hpp @@ -112,7 +112,7 @@ Exx_Lip::Exx_Lip(const Exx_Info::Exx_Info_Lip& info_in, #endif this->k_pack->wf_wg.create(this->k_pack->kv_ptr->get_nks(),PARAM.inp.nbands); - this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk.data(), true); + this->k_pack->hvec_array = new psi::Psi(this->k_pack->kv_ptr->get_nks(), PARAM.inp.nbands, PARAM.globalv.nlocal, kv_ptr_in->ngk, true); // this->k_pack->hvec_array = new ModuleBase::ComplexMatrix[this->k_pack->kv_ptr->get_nks()]; // for( int ik=0; ikk_pack->kv_ptr->get_nks(); ++ik) // { From f26708b0ca862ac46234bf0fdabf89d386fcdfb5 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 17 Jan 2025 02:07:44 -0800 Subject: [PATCH 13/14] update psi --- source/module_psi/psi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 1f0c593d1c..f96d7a0543 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -154,7 +154,7 @@ Psi::Psi(const int nk_in, const bool k_first_in) { // Currently this function only supports nk_in == 1 when called within diagH_subspace_init. - assert(nk_in == 1); + // assert(nk_in == 1); this->k_first = k_first_in; this->allocate_inside = true; From efc0278c273c33780bacffea0e639dd5507d6230 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Fri, 17 Jan 2025 21:53:13 +0800 Subject: [PATCH 14/14] remove useless code --- source/module_psi/psi.cpp | 35 +---------------------------------- source/module_psi/psi.h | 5 +---- 2 files changed, 2 insertions(+), 38 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f96d7a0543..04118f7bcd 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -43,40 +43,7 @@ Psi::~Psi() } } -// // Constructor 1-1: -// template -// Psi::Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in) -// { -// assert(nk_in > 0); -// assert(nbd_in >= 0); // 187_PW_SDFT_ALL_GPU && 187_PW_MD_SDFT_ALL_GPU -// assert(nbs_in > 0); - -// this->k_first = k_first_in; -// this->allocate_inside = true; - -// this->ngk = ngk_in; // modify later -// // This function will delete the psi array first(if psi exist), then malloc a new memory for it. -// resize_memory_op()(this->ctx, this->psi, nk_in * static_cast(nbd_in) * nbs_in, "no_record"); - -// this->nk = nk_in; -// this->nbands = nbd_in; -// this->nbasis = nbs_in; - -// this->current_b = 0; -// this->current_k = 0; -// this->current_nbasis = nbs_in; -// this->psi_current = this->psi; -// this->psi_bias = 0; - -// // Currently only GPU's implementation is supported for device recording! -// base_device::information::print_device_info(this->ctx, GlobalV::ofs_device); -// base_device::information::record_device_memory(this->ctx, -// GlobalV::ofs_device, -// "Psi->resize()", -// sizeof(T) * nk_in * nbd_in * nbs_in); -// } - -// Constructor 1-2: +// Constructor 1: template Psi::Psi(const int nk_in, const int nbd_in, diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 7d4aaa7d61..75e13433ea 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -39,10 +39,7 @@ class Psi // Constructor 0: basic Psi(); - // // Constructor 1-1: specify nk, nbands, nbasis, ngk, and do not need to call resize() later - // Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in, const bool k_first_in); - - // Constructor 1-2: + // Constructor 1: Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector& ngk_in, const bool k_first_in); // Constructor 2-1: initialize a new psi from the given psi_in