Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 29 additions & 34 deletions source/module_base/array_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,51 @@ namespace ModuleBase
class Array_Pool
{
public:
Array_Pool();
Array_Pool(const int nr, const int nc);
Array_Pool() = default;
Array_Pool(const int nr_in, const int nc_in);
Array_Pool(Array_Pool<T>&& other);
Array_Pool& operator=(Array_Pool<T>&& other);
~Array_Pool();
Array_Pool(const Array_Pool<T>& other) = delete;
Array_Pool& operator=(const Array_Pool& other) = delete;

T** get_ptr_2D() const { return ptr_2D; }
T* get_ptr_1D() const { return ptr_1D; }
int get_nr() const { return nr; }
int get_nc() const { return nc; }
T* operator[](const int ir) const { return ptr_2D[ir]; }
T** get_ptr_2D() const { return this->ptr_2D; }
T* get_ptr_1D() const { return this->ptr_1D; }
int get_nr() const { return this->nr; }
int get_nc() const { return this->nc; }
T* operator[](const int ir) const { return this->ptr_2D[ir]; }
private:
T** ptr_2D;
T* ptr_1D;
int nr;
int nc;
T** ptr_2D = nullptr;
T* ptr_1D = nullptr;
int nr = 0;
int nc = 0;
};

template <typename T>
Array_Pool<T>::Array_Pool() : ptr_2D(nullptr), ptr_1D(nullptr), nr(0), nc(0)
Array_Pool<T>::Array_Pool(const int nr_in, const int nc_in) // Attention: uninitialized
: nr(nr_in),
nc(nc_in)
{
}

template <typename T>
Array_Pool<T>::Array_Pool(const int nr, const int nc) // Attention: uninitialized
{
this->nr = nr;
this->nc = nc;
ptr_1D = new T[nr * nc];
ptr_2D = new T*[nr];
this->ptr_1D = new T[nr * nc];
this->ptr_2D = new T*[nr];
for (int ir = 0; ir < nr; ++ir)
ptr_2D[ir] = &ptr_1D[ir * nc];
this->ptr_2D[ir] = &this->ptr_1D[ir * nc];
}

template <typename T>
Array_Pool<T>::~Array_Pool()
{
delete[] ptr_2D;
delete[] ptr_1D;
delete[] this->ptr_2D;
delete[] this->ptr_1D;
}

template <typename T>
Array_Pool<T>::Array_Pool(Array_Pool<T>&& other)
: ptr_2D(other.ptr_2D),
ptr_1D(other.ptr_1D),
nr(other.nr),
nc(other.nc)
{
ptr_2D = other.ptr_2D;
ptr_1D = other.ptr_1D;
nr = other.nr;
nc = other.nc;
other.ptr_2D = nullptr;
other.ptr_1D = nullptr;
other.nr = 0;
Expand All @@ -76,12 +71,12 @@ namespace ModuleBase
{
if (this != &other)
{
delete[] ptr_2D;
delete[] ptr_1D;
ptr_2D = other.ptr_2D;
ptr_1D = other.ptr_1D;
nr = other.nr;
nc = other.nc;
delete[] this->ptr_2D;
delete[] this->ptr_1D;
this->ptr_2D = other.ptr_2D;
this->ptr_1D = other.ptr_1D;
this->nr = other.nr;
this->nc = other.nc;
other.ptr_2D = nullptr;
other.ptr_1D = nullptr;
other.nr = 0;
Expand Down
3 changes: 2 additions & 1 deletion source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#include "module_base/ylm.h"
namespace Gint_Tools{
void cal_psir_ylm(
const Grid_Technique& gt, const int bxyz,
const Grid_Technique& gt,
const int bxyz,
const int na_grid, // number of atoms on this grid
const int grid_index, // 1d index of FFT index (i,j,k)
const double delta_r, // delta_r of the uniform FFT grid
Expand Down
51 changes: 35 additions & 16 deletions source/module_hamilt_lcao/module_gint/gint.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_hamilt_lcao/module_gint/grid_technique.h"
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"

#include <functional>

class Gint {
public:
~Gint();
Expand Down Expand Up @@ -64,6 +67,21 @@ class Gint {
const Grid_Technique* gridt = nullptr;
const UnitCell* ucell;

// psir_ylm_new = psir_func(psir_ylm)
// psir_func==nullptr means psir_ylm_new=psir_ylm
using T_psir_func = std::function<
const ModuleBase::Array_Pool<double>&(
const ModuleBase::Array_Pool<double> &psir_ylm,
const Grid_Technique &gt,
const int grid_index,
const int is,
const std::vector<int> &block_iw,
const std::vector<int> &block_size,
const std::vector<int> &block_index,
const ModuleBase::Array_Pool<bool> &cal_flag)>;
T_psir_func psir_func_1 = nullptr;
T_psir_func psir_func_2 = nullptr;

protected:
// variables related to FFT grid
int nbx;
Expand Down Expand Up @@ -152,17 +170,18 @@ class Gint {
hamilt::HContainer<double>* hR); // HContainer for storing the <phi_0 |
// V | phi_R> matrix element.

void cal_meshball_vlocal_k(int na_grid,
const int LD_pool,
int grid_index,
int* block_size,
int* block_index,
int* block_iw,
bool** cal_flag,
double** psir_ylm,
double** psir_vlbr3,
double* pvpR,
const UnitCell& ucell);
void cal_meshball_vlocal_k(
const int na_grid,
const int LD_pool,
const int grid_index,
const int*const block_size,
const int*const block_index,
const int*const block_iw,
const bool*const*const cal_flag,
const double*const*const psir_ylm,
const double*const*const psir_vlbr3,
double*const pvpR,
const UnitCell &ucell);

//------------------------------------------------------
// in gint_fvl.cpp
Expand Down Expand Up @@ -225,11 +244,11 @@ class Gint {
Gint_inout* inout);

void cal_meshball_rho(const int na_grid,
int* block_index,
int* vindex,
double** psir_ylm,
double** psir_DMR,
double* rho);
const int*const block_index,
const int*const vindex,
const double*const*const psir_ylm,
const double*const*const psir_DMR,
double*const rho);

void gint_kernel_tau(const int na_grid,
const int grid_index,
Expand Down
12 changes: 6 additions & 6 deletions source/module_hamilt_lcao/module_gint/gint_rho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
#include "module_hamilt_pw/hamilt_pwdft/global.h"

void Gint::cal_meshball_rho(const int na_grid,
int* block_index,
int* vindex,
double** psir_ylm,
double** psir_DMR,
double* rho)
const int*const block_index,
const int*const vindex,
const double*const*const psir_ylm,
const double*const*const psir_DMR,
double*const rho)
{
const int inc = 1;
// sum over mu to get density on grid
for (int ib = 0; ib < this->bxyz; ++ib)
{
double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
const double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
const int grid = vindex[ib];
rho[grid] += r;
}
Expand Down
37 changes: 22 additions & 15 deletions source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
const int ncyz = this->ny * this->nplane;
const double delta_r = this->gridt->dr_uniform;

#pragma omp parallel
#pragma omp parallel
{
std::vector<int> block_iw(max_size, 0);
std::vector<int> block_index(max_size+1, 0);
std::vector<int> block_size(max_size, 0);
std::vector<int> vindex(bxyz, 0);
std::vector<int> vindex(this->bxyz, 0);
#pragma omp for
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
{
const int na_grid = this->gridt->how_many_atoms[grid_index];
if (na_grid == 0) {
continue;
Expand All @@ -41,7 +42,7 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
block_size.data(),
cal_flag.get_ptr_2D());

// evaluate psi on grids
// evaluate psi on grids
const int LD_pool = block_index[na_grid];
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
Gint_Tools::cal_psir_ylm(*this->gridt,
Expand All @@ -56,6 +57,11 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {

for (int is = 0; is < inout->nspin_rho; ++is)
{
// psir_ylm_new = psir_func(psir_ylm)
// psir_func==nullptr means psir_ylm_new=psir_ylm
const ModuleBase::Array_Pool<double> &psir_ylm_1 = (!this->psir_func_1) ? psir_ylm : this->psir_func_1(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);
const ModuleBase::Array_Pool<double> &psir_ylm_2 = (!this->psir_func_2) ? psir_ylm : this->psir_func_2(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);

ModuleBase::Array_Pool<double> psir_DM(this->bxyz, LD_pool);
ModuleBase::GlobalFunc::ZEROS(psir_DM.get_ptr_1D(), this->bxyz * LD_pool);

Expand All @@ -68,13 +74,13 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
block_index.data(),
block_size.data(),
cal_flag.get_ptr_2D(),
psir_ylm.get_ptr_2D(),
psir_ylm_1.get_ptr_2D(),
psir_DM.get_ptr_2D(),
this->DMRGint[is],
inout->if_symm);

// do sum_mu g_mu(r)psi_mu(r) to get electron density on grid
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm_2.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
}
}
}
Expand All @@ -90,14 +96,15 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
const double delta_r = this->gridt->dr_uniform;


#pragma omp parallel
#pragma omp parallel
{
std::vector<int> block_iw(max_size, 0);
std::vector<int> block_index(max_size+1, 0);
std::vector<int> block_size(max_size, 0);
std::vector<int> vindex(bxyz, 0);
#pragma omp for
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
{
const int na_grid = this->gridt->how_many_atoms[grid_index];
if (na_grid == 0) {
continue;
Expand All @@ -112,19 +119,19 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
vindex.data());
//prepare block information
ModuleBase::Array_Pool<bool> cal_flag(this->bxyz,max_size);
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
block_iw.data(), block_index.data(), block_size.data(), cal_flag.get_ptr_2D());

//evaluate psi and dpsi on grids
//evaluate psi and dpsi on grids
const int LD_pool = block_index[na_grid];
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_x(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_y(this->bxyz, LD_pool);
ModuleBase::Array_Pool<double> dpsir_ylm_z(this->bxyz, LD_pool);

Gint_Tools::cal_dpsir_ylm(*this->gridt,
Gint_Tools::cal_dpsir_ylm(*this->gridt,
this->bxyz, na_grid, grid_index, delta_r,
block_index.data(), block_size.data(),
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
psir_ylm.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(),
Expand All @@ -146,7 +153,7 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
LD_pool,
grid_index, na_grid,
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
cal_flag.get_ptr_2D(),
dpsir_ylm_x.get_ptr_2D(),
dpsix_DM.get_ptr_2D(),
this->DMRGint[is],
Expand All @@ -166,13 +173,13 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
LD_pool,
grid_index, na_grid,
block_index.data(), block_size.data(),
cal_flag.get_ptr_2D(),
cal_flag.get_ptr_2D(),
dpsir_ylm_z.get_ptr_2D(),
dpsiz_DM.get_ptr_2D(),
this->DMRGint[is],
true);

//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
if(inout->job==Gint_Tools::job_type::tau)
{
this->cal_meshball_tau(
Expand Down
25 changes: 13 additions & 12 deletions source/module_hamilt_lcao/module_gint/gint_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,19 @@ ModuleBase::Array_Pool<double> get_psir_vlbr3(
const double* const* const psir_ylm); // psir_ylm[bxyz][LD_pool]

// sum_nu,R rho_mu,nu(R) psi_nu, for multi-k and gamma point
void mult_psi_DMR(const Grid_Technique& gt,
const int bxyz,
const int LD_pool,
const int& grid_index,
const int& na_grid,
const int* const block_index,
const int* const block_size,
bool** cal_flag,
double** psi,
double** psi_DMR,
const hamilt::HContainer<double>* DM,
const bool if_symm);
void mult_psi_DMR(
const Grid_Technique& gt,
const int bxyz,
const int LD_pool,
const int &grid_index,
const int &na_grid,
const int*const block_index,
const int*const block_size,
const bool*const*const cal_flag,
const double*const*const psi,
double*const*const psi_DMR,
const hamilt::HContainer<double>*const DM,
const bool if_symm);


// pair.first is the first index of the meshcell which is inside atoms ia1 and ia2.
Expand Down
Loading
Loading