From e895a00b774f2ffbab303bd24cfae2f65c788640 Mon Sep 17 00:00:00 2001 From: linpz Date: Thu, 31 Oct 2024 16:48:09 +0800 Subject: [PATCH] Feature: add interface Gint::psir_func --- source/module_base/array_pool.h | 63 +++++++++---------- .../module_gint/cal_psir_ylm.cpp | 3 +- source/module_hamilt_lcao/module_gint/gint.h | 51 ++++++++++----- .../module_gint/gint_rho.cpp | 12 ++-- .../module_gint/gint_rho_cpu_interface.cpp | 37 ++++++----- .../module_gint/gint_tools.h | 25 ++++---- .../module_gint/gint_vl.cpp | 46 +++++++------- .../module_gint/gint_vl_cpu_interface.cpp | 13 ++-- .../module_gint/mult_psi_dmr.cpp | 16 ++++- 9 files changed, 150 insertions(+), 116 deletions(-) diff --git a/source/module_base/array_pool.h b/source/module_base/array_pool.h index 2812ec72d0..0d9d175c1d 100644 --- a/source/module_base/array_pool.h +++ b/source/module_base/array_pool.h @@ -15,56 +15,51 @@ namespace ModuleBase class Array_Pool { public: - Array_Pool(); - Array_Pool(const int nr, const int nc); + Array_Pool() = default; + Array_Pool(const int nr_in, const int nc_in); Array_Pool(Array_Pool&& other); Array_Pool& operator=(Array_Pool&& other); ~Array_Pool(); Array_Pool(const Array_Pool& other) = delete; Array_Pool& operator=(const Array_Pool& other) = delete; - T** get_ptr_2D() const { return ptr_2D; } - T* get_ptr_1D() const { return ptr_1D; } - int get_nr() const { return nr; } - int get_nc() const { return nc; } - T* operator[](const int ir) const { return ptr_2D[ir]; } + T** get_ptr_2D() const { return this->ptr_2D; } + T* get_ptr_1D() const { return this->ptr_1D; } + int get_nr() const { return this->nr; } + int get_nc() const { return this->nc; } + T* operator[](const int ir) const { return this->ptr_2D[ir]; } private: - T** ptr_2D; - T* ptr_1D; - int nr; - int nc; + T** ptr_2D = nullptr; + T* ptr_1D = nullptr; + int nr = 0; + int nc = 0; }; template - Array_Pool::Array_Pool() : ptr_2D(nullptr), ptr_1D(nullptr), nr(0), nc(0) + Array_Pool::Array_Pool(const int nr_in, const int nc_in) // Attention: uninitialized + : nr(nr_in), + nc(nc_in) { - } - - template - Array_Pool::Array_Pool(const int nr, const int nc) // Attention: uninitialized - { - this->nr = nr; - this->nc = nc; - ptr_1D = new T[nr * nc]; - ptr_2D = new T*[nr]; + this->ptr_1D = new T[nr * nc]; + this->ptr_2D = new T*[nr]; for (int ir = 0; ir < nr; ++ir) - ptr_2D[ir] = &ptr_1D[ir * nc]; + this->ptr_2D[ir] = &this->ptr_1D[ir * nc]; } template Array_Pool::~Array_Pool() { - delete[] ptr_2D; - delete[] ptr_1D; + delete[] this->ptr_2D; + delete[] this->ptr_1D; } template Array_Pool::Array_Pool(Array_Pool&& other) + : ptr_2D(other.ptr_2D), + ptr_1D(other.ptr_1D), + nr(other.nr), + nc(other.nc) { - ptr_2D = other.ptr_2D; - ptr_1D = other.ptr_1D; - nr = other.nr; - nc = other.nc; other.ptr_2D = nullptr; other.ptr_1D = nullptr; other.nr = 0; @@ -76,12 +71,12 @@ namespace ModuleBase { if (this != &other) { - delete[] ptr_2D; - delete[] ptr_1D; - ptr_2D = other.ptr_2D; - ptr_1D = other.ptr_1D; - nr = other.nr; - nc = other.nc; + delete[] this->ptr_2D; + delete[] this->ptr_1D; + this->ptr_2D = other.ptr_2D; + this->ptr_1D = other.ptr_1D; + this->nr = other.nr; + this->nc = other.nc; other.ptr_2D = nullptr; other.ptr_1D = nullptr; other.nr = 0; diff --git a/source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp index a365be1fb7..d1bda7fa08 100644 --- a/source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp @@ -3,7 +3,8 @@ #include "module_base/ylm.h" namespace Gint_Tools{ void cal_psir_ylm( - const Grid_Technique& gt, const int bxyz, + const Grid_Technique& gt, + const int bxyz, const int na_grid, // number of atoms on this grid const int grid_index, // 1d index of FFT index (i,j,k) const double delta_r, // delta_r of the uniform FFT grid diff --git a/source/module_hamilt_lcao/module_gint/gint.h b/source/module_hamilt_lcao/module_gint/gint.h index 6be8adbc20..f7933ab639 100644 --- a/source/module_hamilt_lcao/module_gint/gint.h +++ b/source/module_hamilt_lcao/module_gint/gint.h @@ -13,6 +13,9 @@ #include "module_cell/module_neighbor/sltk_grid_driver.h" #include "module_hamilt_lcao/module_gint/grid_technique.h" #include "module_hamilt_lcao/module_hcontainer/hcontainer.h" + +#include + class Gint { public: ~Gint(); @@ -64,6 +67,21 @@ class Gint { const Grid_Technique* gridt = nullptr; const UnitCell* ucell; + // psir_ylm_new = psir_func(psir_ylm) + // psir_func==nullptr means psir_ylm_new=psir_ylm + using T_psir_func = std::function< + const ModuleBase::Array_Pool&( + const ModuleBase::Array_Pool &psir_ylm, + const Grid_Technique >, + const int grid_index, + const int is, + const std::vector &block_iw, + const std::vector &block_size, + const std::vector &block_index, + const ModuleBase::Array_Pool &cal_flag)>; + T_psir_func psir_func_1 = nullptr; + T_psir_func psir_func_2 = nullptr; + protected: // variables related to FFT grid int nbx; @@ -152,17 +170,18 @@ class Gint { hamilt::HContainer* hR); // HContainer for storing the matrix element. - void cal_meshball_vlocal_k(int na_grid, - const int LD_pool, - int grid_index, - int* block_size, - int* block_index, - int* block_iw, - bool** cal_flag, - double** psir_ylm, - double** psir_vlbr3, - double* pvpR, - const UnitCell& ucell); + void cal_meshball_vlocal_k( + const int na_grid, + const int LD_pool, + const int grid_index, + const int*const block_size, + const int*const block_index, + const int*const block_iw, + const bool*const*const cal_flag, + const double*const*const psir_ylm, + const double*const*const psir_vlbr3, + double*const pvpR, + const UnitCell &ucell); //------------------------------------------------------ // in gint_fvl.cpp @@ -225,11 +244,11 @@ class Gint { Gint_inout* inout); void cal_meshball_rho(const int na_grid, - int* block_index, - int* vindex, - double** psir_ylm, - double** psir_DMR, - double* rho); + const int*const block_index, + const int*const vindex, + const double*const*const psir_ylm, + const double*const*const psir_DMR, + double*const rho); void gint_kernel_tau(const int na_grid, const int grid_index, diff --git a/source/module_hamilt_lcao/module_gint/gint_rho.cpp b/source/module_hamilt_lcao/module_gint/gint_rho.cpp index 87d4b2fae2..17ca9d78c1 100644 --- a/source/module_hamilt_lcao/module_gint/gint_rho.cpp +++ b/source/module_hamilt_lcao/module_gint/gint_rho.cpp @@ -11,17 +11,17 @@ #include "module_hamilt_pw/hamilt_pwdft/global.h" void Gint::cal_meshball_rho(const int na_grid, - int* block_index, - int* vindex, - double** psir_ylm, - double** psir_DMR, - double* rho) + const int*const block_index, + const int*const vindex, + const double*const*const psir_ylm, + const double*const*const psir_DMR, + double*const rho) { const int inc = 1; // sum over mu to get density on grid for (int ib = 0; ib < this->bxyz; ++ib) { - double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc); + const double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc); const int grid = vindex[ib]; rho[grid] += r; } diff --git a/source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp b/source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp index 160df1377b..dfc4926b0d 100644 --- a/source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp +++ b/source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp @@ -10,14 +10,15 @@ void Gint::gint_kernel_rho(Gint_inout* inout) { const int ncyz = this->ny * this->nplane; const double delta_r = this->gridt->dr_uniform; -#pragma omp parallel +#pragma omp parallel { std::vector block_iw(max_size, 0); std::vector block_index(max_size+1, 0); std::vector block_size(max_size, 0); - std::vector vindex(bxyz, 0); + std::vector vindex(this->bxyz, 0); #pragma omp for - for (int grid_index = 0; grid_index < this->nbxx; grid_index++) { + for (int grid_index = 0; grid_index < this->nbxx; grid_index++) + { const int na_grid = this->gridt->how_many_atoms[grid_index]; if (na_grid == 0) { continue; @@ -41,7 +42,7 @@ void Gint::gint_kernel_rho(Gint_inout* inout) { block_size.data(), cal_flag.get_ptr_2D()); - // evaluate psi on grids + // evaluate psi on grids const int LD_pool = block_index[na_grid]; ModuleBase::Array_Pool psir_ylm(this->bxyz, LD_pool); Gint_Tools::cal_psir_ylm(*this->gridt, @@ -56,6 +57,11 @@ void Gint::gint_kernel_rho(Gint_inout* inout) { for (int is = 0; is < inout->nspin_rho; ++is) { + // psir_ylm_new = psir_func(psir_ylm) + // psir_func==nullptr means psir_ylm_new=psir_ylm + const ModuleBase::Array_Pool &psir_ylm_1 = (!this->psir_func_1) ? psir_ylm : this->psir_func_1(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag); + const ModuleBase::Array_Pool &psir_ylm_2 = (!this->psir_func_2) ? psir_ylm : this->psir_func_2(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag); + ModuleBase::Array_Pool psir_DM(this->bxyz, LD_pool); ModuleBase::GlobalFunc::ZEROS(psir_DM.get_ptr_1D(), this->bxyz * LD_pool); @@ -68,13 +74,13 @@ void Gint::gint_kernel_rho(Gint_inout* inout) { block_index.data(), block_size.data(), cal_flag.get_ptr_2D(), - psir_ylm.get_ptr_2D(), + psir_ylm_1.get_ptr_2D(), psir_DM.get_ptr_2D(), this->DMRGint[is], inout->if_symm); // do sum_mu g_mu(r)psi_mu(r) to get electron density on grid - this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]); + this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm_2.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]); } } } @@ -90,14 +96,15 @@ void Gint::gint_kernel_tau(Gint_inout* inout) { const double delta_r = this->gridt->dr_uniform; -#pragma omp parallel +#pragma omp parallel { std::vector block_iw(max_size, 0); std::vector block_index(max_size+1, 0); std::vector block_size(max_size, 0); std::vector vindex(bxyz, 0); #pragma omp for - for (int grid_index = 0; grid_index < this->nbxx; grid_index++) { + for (int grid_index = 0; grid_index < this->nbxx; grid_index++) + { const int na_grid = this->gridt->how_many_atoms[grid_index]; if (na_grid == 0) { continue; @@ -112,19 +119,19 @@ void Gint::gint_kernel_tau(Gint_inout* inout) { vindex.data()); //prepare block information ModuleBase::Array_Pool cal_flag(this->bxyz,max_size); - Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index, + Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index, block_iw.data(), block_index.data(), block_size.data(), cal_flag.get_ptr_2D()); - //evaluate psi and dpsi on grids + //evaluate psi and dpsi on grids const int LD_pool = block_index[na_grid]; ModuleBase::Array_Pool psir_ylm(this->bxyz, LD_pool); ModuleBase::Array_Pool dpsir_ylm_x(this->bxyz, LD_pool); ModuleBase::Array_Pool dpsir_ylm_y(this->bxyz, LD_pool); ModuleBase::Array_Pool dpsir_ylm_z(this->bxyz, LD_pool); - Gint_Tools::cal_dpsir_ylm(*this->gridt, + Gint_Tools::cal_dpsir_ylm(*this->gridt, this->bxyz, na_grid, grid_index, delta_r, - block_index.data(), block_size.data(), + block_index.data(), block_size.data(), cal_flag.get_ptr_2D(), psir_ylm.get_ptr_2D(), dpsir_ylm_x.get_ptr_2D(), @@ -146,7 +153,7 @@ void Gint::gint_kernel_tau(Gint_inout* inout) { LD_pool, grid_index, na_grid, block_index.data(), block_size.data(), - cal_flag.get_ptr_2D(), + cal_flag.get_ptr_2D(), dpsir_ylm_x.get_ptr_2D(), dpsix_DM.get_ptr_2D(), this->DMRGint[is], @@ -166,13 +173,13 @@ void Gint::gint_kernel_tau(Gint_inout* inout) { LD_pool, grid_index, na_grid, block_index.data(), block_size.data(), - cal_flag.get_ptr_2D(), + cal_flag.get_ptr_2D(), dpsir_ylm_z.get_ptr_2D(), dpsiz_DM.get_ptr_2D(), this->DMRGint[is], true); - //do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid + //do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid if(inout->job==Gint_Tools::job_type::tau) { this->cal_meshball_tau( diff --git a/source/module_hamilt_lcao/module_gint/gint_tools.h b/source/module_hamilt_lcao/module_gint/gint_tools.h index 7466cac2a2..3dd69e173c 100644 --- a/source/module_hamilt_lcao/module_gint/gint_tools.h +++ b/source/module_hamilt_lcao/module_gint/gint_tools.h @@ -284,18 +284,19 @@ ModuleBase::Array_Pool get_psir_vlbr3( const double* const* const psir_ylm); // psir_ylm[bxyz][LD_pool] // sum_nu,R rho_mu,nu(R) psi_nu, for multi-k and gamma point -void mult_psi_DMR(const Grid_Technique& gt, - const int bxyz, - const int LD_pool, - const int& grid_index, - const int& na_grid, - const int* const block_index, - const int* const block_size, - bool** cal_flag, - double** psi, - double** psi_DMR, - const hamilt::HContainer* DM, - const bool if_symm); +void mult_psi_DMR( + const Grid_Technique& gt, + const int bxyz, + const int LD_pool, + const int &grid_index, + const int &na_grid, + const int*const block_index, + const int*const block_size, + const bool*const*const cal_flag, + const double*const*const psi, + double*const*const psi_DMR, + const hamilt::HContainer*const DM, + const bool if_symm); // pair.first is the first index of the meshcell which is inside atoms ia1 and ia2. diff --git a/source/module_hamilt_lcao/module_gint/gint_vl.cpp b/source/module_hamilt_lcao/module_gint/gint_vl.cpp index 0a18f0c9dd..b435dd88bc 100644 --- a/source/module_hamilt_lcao/module_gint/gint_vl.cpp +++ b/source/module_hamilt_lcao/module_gint/gint_vl.cpp @@ -66,8 +66,7 @@ void Gint::cal_meshball_vlocal_gamma( } } const int ib_length = last_ib-first_ib; - if(ib_length<=0) { continue; -} + if(ib_length<=0) { continue; } // calculate the BaseMatrix of atom-pair hamilt::AtomPair* tmp_ap = hR->find_pair(iat1, iat2); @@ -81,14 +80,13 @@ void Gint::cal_meshball_vlocal_gamma( } const int n=block_size[ia2]; - //std::cout<<__FILE__<<__LINE__<<" "<get_row_size()<<" "<get_col_size()<ib_length/4) { dgemm_(&transa, &transb, &n, &m, &ib_length, &alpha, &psir_vlbr3[first_ib][block_index[ia2]], &LD_pool, &psir_ylm[first_ib][block_index[ia1]], &LD_pool, &beta, tmp_ap->get_pointer(0), &n); - //&GridVlocal[iw1_lo*lgd_now+iw2_lo], &lgd_now); + //&GridVlocal[iw1_lo*lgd_now+iw2_lo], &lgd_now); } else { @@ -96,32 +94,31 @@ void Gint::cal_meshball_vlocal_gamma( { if(cal_flag[ib][ia1] && cal_flag[ib][ia2]) { - int k=1; + int k=1; dgemm_(&transa, &transb, &n, &m, &k, &alpha, &psir_vlbr3[ib][block_index[ia2]], &LD_pool, &psir_ylm[ib][block_index[ia1]], &LD_pool, - &beta, tmp_ap->get_pointer(0), &n); + &beta, tmp_ap->get_pointer(0), &n); } } } - //std::cout<<__FILE__<<__LINE__<<" "<get_pointer(0)[2]<gridt->find_offset(id1, id2, iat1, iat2); - const int iatw = DM_start + this->gridt->find_R2st[iat1][offset]; + const int iatw = DM_start + this->gridt->find_R2st[iat1][offset]; if(cal_num>this->bxyz/4) { k=this->bxyz; dgemm_(&transa, &transb, &n, &m, &k, &alpha, - &psir_vlbr3[0][idx2], &LD_pool, + &psir_vlbr3[0][idx2], &LD_pool, &psir_ylm[0][idx1], &LD_pool, &beta, &pvpR[iatw], &n); } @@ -183,9 +179,9 @@ void Gint::cal_meshball_vlocal_k( { k=1; dgemm_(&transa, &transb, &n, &m, &k, &alpha, - &psir_vlbr3[ib][idx2], &LD_pool, + &psir_vlbr3[ib][idx2], &LD_pool, &psir_ylm[ib][idx1], &LD_pool, - &beta, &pvpR[iatw], &n); + &beta, &pvpR[iatw], &n); } } } diff --git a/source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp b/source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp index ed9b21434a..b20c1182d1 100644 --- a/source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp +++ b/source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp @@ -73,11 +73,16 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) { this->bxyz, na_grid, grid_index, delta_r, block_index.data(), block_size.data(), cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D()); - + + // psir_ylm_new=psir_func(psir_ylm) + // psir_func==nullptr means psir_ylm_new=psir_ylm + const ModuleBase::Array_Pool &psir_ylm_1 = (!this->psir_func_1) ? psir_ylm : this->psir_func_1(psir_ylm, *this->gridt, grid_index, 0, block_iw, block_size, block_index, cal_flag); + const ModuleBase::Array_Pool &psir_ylm_2 = (!this->psir_func_2) ? psir_ylm : this->psir_func_2(psir_ylm, *this->gridt, grid_index, 0, block_iw, block_size, block_index, cal_flag); + //calculating f_mu(r) = v(r)*psi_mu(r)*dv const ModuleBase::Array_Pool psir_vlbr3 = Gint_Tools::get_psir_vlbr3( this->bxyz, na_grid, LD_pool, block_index.data(), - cal_flag.get_ptr_2D(), vldr3.data(), psir_ylm.get_ptr_2D()); + cal_flag.get_ptr_2D(), vldr3.data(), psir_ylm_1.get_ptr_2D()); //integrate (psi_mu*v(r)*dv) * psi_nu on grid //and accumulates to the corresponding element in Hamiltonian @@ -85,14 +90,14 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) { { this->cal_meshball_vlocal_gamma( na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, - cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), + cal_flag.get_ptr_2D(), psir_ylm_2.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), hRGint_thread); } else { this->cal_meshball_vlocal_k( na_grid, LD_pool, grid_index, block_size.data(), block_index.data(), block_iw.data(), - cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), + cal_flag.get_ptr_2D(), psir_ylm_2.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), pvpR_thread,ucell); } diff --git a/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp b/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp index 4ec4ad8abf..b7facc2ef3 100644 --- a/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp +++ b/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp @@ -2,9 +2,19 @@ #include "module_base/timer.h" #include "module_base/ylm.h" namespace Gint_Tools{ -void mult_psi_DMR(const Grid_Technique& gt, const int bxyz, const int LD_pool, const int& grid_index, const int& na_grid, - const int* const block_index, const int* const block_size, bool** cal_flag, double** psi, - double** psi_DMR, const hamilt::HContainer* DM, const bool if_symm) +void mult_psi_DMR( + const Grid_Technique& gt, + const int bxyz, + const int LD_pool, + const int &grid_index, + const int &na_grid, + const int*const block_index, + const int*const block_size, + const bool*const*const cal_flag, + const double*const*const psi, + double*const*const psi_DMR, + const hamilt::HContainer*const DM, + const bool if_symm) { const UnitCell& ucell = *gt.ucell;