Skip to content

Commit 3da2868

Browse files
authored
Feature: add interface Gint::psir_func (#5380)
1 parent 01beb19 commit 3da2868

File tree

9 files changed

+150
-116
lines changed

9 files changed

+150
-116
lines changed

source/module_base/array_pool.h

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,56 +15,51 @@ namespace ModuleBase
1515
class Array_Pool
1616
{
1717
public:
18-
Array_Pool();
19-
Array_Pool(const int nr, const int nc);
18+
Array_Pool() = default;
19+
Array_Pool(const int nr_in, const int nc_in);
2020
Array_Pool(Array_Pool<T>&& other);
2121
Array_Pool& operator=(Array_Pool<T>&& other);
2222
~Array_Pool();
2323
Array_Pool(const Array_Pool<T>& other) = delete;
2424
Array_Pool& operator=(const Array_Pool& other) = delete;
2525

26-
T** get_ptr_2D() const { return ptr_2D; }
27-
T* get_ptr_1D() const { return ptr_1D; }
28-
int get_nr() const { return nr; }
29-
int get_nc() const { return nc; }
30-
T* operator[](const int ir) const { return ptr_2D[ir]; }
26+
T** get_ptr_2D() const { return this->ptr_2D; }
27+
T* get_ptr_1D() const { return this->ptr_1D; }
28+
int get_nr() const { return this->nr; }
29+
int get_nc() const { return this->nc; }
30+
T* operator[](const int ir) const { return this->ptr_2D[ir]; }
3131
private:
32-
T** ptr_2D;
33-
T* ptr_1D;
34-
int nr;
35-
int nc;
32+
T** ptr_2D = nullptr;
33+
T* ptr_1D = nullptr;
34+
int nr = 0;
35+
int nc = 0;
3636
};
3737

3838
template <typename T>
39-
Array_Pool<T>::Array_Pool() : ptr_2D(nullptr), ptr_1D(nullptr), nr(0), nc(0)
39+
Array_Pool<T>::Array_Pool(const int nr_in, const int nc_in) // Attention: uninitialized
40+
: nr(nr_in),
41+
nc(nc_in)
4042
{
41-
}
42-
43-
template <typename T>
44-
Array_Pool<T>::Array_Pool(const int nr, const int nc) // Attention: uninitialized
45-
{
46-
this->nr = nr;
47-
this->nc = nc;
48-
ptr_1D = new T[nr * nc];
49-
ptr_2D = new T*[nr];
43+
this->ptr_1D = new T[nr * nc];
44+
this->ptr_2D = new T*[nr];
5045
for (int ir = 0; ir < nr; ++ir)
51-
ptr_2D[ir] = &ptr_1D[ir * nc];
46+
this->ptr_2D[ir] = &this->ptr_1D[ir * nc];
5247
}
5348

5449
template <typename T>
5550
Array_Pool<T>::~Array_Pool()
5651
{
57-
delete[] ptr_2D;
58-
delete[] ptr_1D;
52+
delete[] this->ptr_2D;
53+
delete[] this->ptr_1D;
5954
}
6055

6156
template <typename T>
6257
Array_Pool<T>::Array_Pool(Array_Pool<T>&& other)
58+
: ptr_2D(other.ptr_2D),
59+
ptr_1D(other.ptr_1D),
60+
nr(other.nr),
61+
nc(other.nc)
6362
{
64-
ptr_2D = other.ptr_2D;
65-
ptr_1D = other.ptr_1D;
66-
nr = other.nr;
67-
nc = other.nc;
6863
other.ptr_2D = nullptr;
6964
other.ptr_1D = nullptr;
7065
other.nr = 0;
@@ -76,12 +71,12 @@ namespace ModuleBase
7671
{
7772
if (this != &other)
7873
{
79-
delete[] ptr_2D;
80-
delete[] ptr_1D;
81-
ptr_2D = other.ptr_2D;
82-
ptr_1D = other.ptr_1D;
83-
nr = other.nr;
84-
nc = other.nc;
74+
delete[] this->ptr_2D;
75+
delete[] this->ptr_1D;
76+
this->ptr_2D = other.ptr_2D;
77+
this->ptr_1D = other.ptr_1D;
78+
this->nr = other.nr;
79+
this->nc = other.nc;
8580
other.ptr_2D = nullptr;
8681
other.ptr_1D = nullptr;
8782
other.nr = 0;

source/module_hamilt_lcao/module_gint/cal_psir_ylm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
#include "module_base/ylm.h"
44
namespace Gint_Tools{
55
void cal_psir_ylm(
6-
const Grid_Technique& gt, const int bxyz,
6+
const Grid_Technique& gt,
7+
const int bxyz,
78
const int na_grid, // number of atoms on this grid
89
const int grid_index, // 1d index of FFT index (i,j,k)
910
const double delta_r, // delta_r of the uniform FFT grid

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "module_cell/module_neighbor/sltk_grid_driver.h"
1414
#include "module_hamilt_lcao/module_gint/grid_technique.h"
1515
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
16+
17+
#include <functional>
18+
1619
class Gint {
1720
public:
1821
~Gint();
@@ -64,6 +67,21 @@ class Gint {
6467
const Grid_Technique* gridt = nullptr;
6568
const UnitCell* ucell;
6669

70+
// psir_ylm_new = psir_func(psir_ylm)
71+
// psir_func==nullptr means psir_ylm_new=psir_ylm
72+
using T_psir_func = std::function<
73+
const ModuleBase::Array_Pool<double>&(
74+
const ModuleBase::Array_Pool<double> &psir_ylm,
75+
const Grid_Technique &gt,
76+
const int grid_index,
77+
const int is,
78+
const std::vector<int> &block_iw,
79+
const std::vector<int> &block_size,
80+
const std::vector<int> &block_index,
81+
const ModuleBase::Array_Pool<bool> &cal_flag)>;
82+
T_psir_func psir_func_1 = nullptr;
83+
T_psir_func psir_func_2 = nullptr;
84+
6785
protected:
6886
// variables related to FFT grid
6987
int nbx;
@@ -152,17 +170,18 @@ class Gint {
152170
hamilt::HContainer<double>* hR); // HContainer for storing the <phi_0 |
153171
// V | phi_R> matrix element.
154172

155-
void cal_meshball_vlocal_k(int na_grid,
156-
const int LD_pool,
157-
int grid_index,
158-
int* block_size,
159-
int* block_index,
160-
int* block_iw,
161-
bool** cal_flag,
162-
double** psir_ylm,
163-
double** psir_vlbr3,
164-
double* pvpR,
165-
const UnitCell& ucell);
173+
void cal_meshball_vlocal_k(
174+
const int na_grid,
175+
const int LD_pool,
176+
const int grid_index,
177+
const int*const block_size,
178+
const int*const block_index,
179+
const int*const block_iw,
180+
const bool*const*const cal_flag,
181+
const double*const*const psir_ylm,
182+
const double*const*const psir_vlbr3,
183+
double*const pvpR,
184+
const UnitCell &ucell);
166185

167186
//------------------------------------------------------
168187
// in gint_fvl.cpp
@@ -225,11 +244,11 @@ class Gint {
225244
Gint_inout* inout);
226245

227246
void cal_meshball_rho(const int na_grid,
228-
int* block_index,
229-
int* vindex,
230-
double** psir_ylm,
231-
double** psir_DMR,
232-
double* rho);
247+
const int*const block_index,
248+
const int*const vindex,
249+
const double*const*const psir_ylm,
250+
const double*const*const psir_DMR,
251+
double*const rho);
233252

234253
void gint_kernel_tau(const int na_grid,
235254
const int grid_index,

source/module_hamilt_lcao/module_gint/gint_rho.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
#include "module_hamilt_pw/hamilt_pwdft/global.h"
1212

1313
void Gint::cal_meshball_rho(const int na_grid,
14-
int* block_index,
15-
int* vindex,
16-
double** psir_ylm,
17-
double** psir_DMR,
18-
double* rho)
14+
const int*const block_index,
15+
const int*const vindex,
16+
const double*const*const psir_ylm,
17+
const double*const*const psir_DMR,
18+
double*const rho)
1919
{
2020
const int inc = 1;
2121
// sum over mu to get density on grid
2222
for (int ib = 0; ib < this->bxyz; ++ib)
2323
{
24-
double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
24+
const double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc);
2525
const int grid = vindex[ib];
2626
rho[grid] += r;
2727
}

source/module_hamilt_lcao/module_gint/gint_rho_cpu_interface.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
1010
const int ncyz = this->ny * this->nplane;
1111
const double delta_r = this->gridt->dr_uniform;
1212

13-
#pragma omp parallel
13+
#pragma omp parallel
1414
{
1515
std::vector<int> block_iw(max_size, 0);
1616
std::vector<int> block_index(max_size+1, 0);
1717
std::vector<int> block_size(max_size, 0);
18-
std::vector<int> vindex(bxyz, 0);
18+
std::vector<int> vindex(this->bxyz, 0);
1919
#pragma omp for
20-
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
20+
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
21+
{
2122
const int na_grid = this->gridt->how_many_atoms[grid_index];
2223
if (na_grid == 0) {
2324
continue;
@@ -41,7 +42,7 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
4142
block_size.data(),
4243
cal_flag.get_ptr_2D());
4344

44-
// evaluate psi on grids
45+
// evaluate psi on grids
4546
const int LD_pool = block_index[na_grid];
4647
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
4748
Gint_Tools::cal_psir_ylm(*this->gridt,
@@ -56,6 +57,11 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
5657

5758
for (int is = 0; is < inout->nspin_rho; ++is)
5859
{
60+
// psir_ylm_new = psir_func(psir_ylm)
61+
// psir_func==nullptr means psir_ylm_new=psir_ylm
62+
const ModuleBase::Array_Pool<double> &psir_ylm_1 = (!this->psir_func_1) ? psir_ylm : this->psir_func_1(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);
63+
const ModuleBase::Array_Pool<double> &psir_ylm_2 = (!this->psir_func_2) ? psir_ylm : this->psir_func_2(psir_ylm, *this->gridt, grid_index, is, block_iw, block_size, block_index, cal_flag);
64+
5965
ModuleBase::Array_Pool<double> psir_DM(this->bxyz, LD_pool);
6066
ModuleBase::GlobalFunc::ZEROS(psir_DM.get_ptr_1D(), this->bxyz * LD_pool);
6167

@@ -68,13 +74,13 @@ void Gint::gint_kernel_rho(Gint_inout* inout) {
6874
block_index.data(),
6975
block_size.data(),
7076
cal_flag.get_ptr_2D(),
71-
psir_ylm.get_ptr_2D(),
77+
psir_ylm_1.get_ptr_2D(),
7278
psir_DM.get_ptr_2D(),
7379
this->DMRGint[is],
7480
inout->if_symm);
7581

7682
// do sum_mu g_mu(r)psi_mu(r) to get electron density on grid
77-
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
83+
this->cal_meshball_rho(na_grid, block_index.data(), vindex.data(), psir_ylm_2.get_ptr_2D(), psir_DM.get_ptr_2D(), inout->rho[is]);
7884
}
7985
}
8086
}
@@ -90,14 +96,15 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
9096
const double delta_r = this->gridt->dr_uniform;
9197

9298

93-
#pragma omp parallel
99+
#pragma omp parallel
94100
{
95101
std::vector<int> block_iw(max_size, 0);
96102
std::vector<int> block_index(max_size+1, 0);
97103
std::vector<int> block_size(max_size, 0);
98104
std::vector<int> vindex(bxyz, 0);
99105
#pragma omp for
100-
for (int grid_index = 0; grid_index < this->nbxx; grid_index++) {
106+
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
107+
{
101108
const int na_grid = this->gridt->how_many_atoms[grid_index];
102109
if (na_grid == 0) {
103110
continue;
@@ -112,19 +119,19 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
112119
vindex.data());
113120
//prepare block information
114121
ModuleBase::Array_Pool<bool> cal_flag(this->bxyz,max_size);
115-
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
122+
Gint_Tools::get_block_info(*this->gridt, this->bxyz, na_grid, grid_index,
116123
block_iw.data(), block_index.data(), block_size.data(), cal_flag.get_ptr_2D());
117124

118-
//evaluate psi and dpsi on grids
125+
//evaluate psi and dpsi on grids
119126
const int LD_pool = block_index[na_grid];
120127
ModuleBase::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
121128
ModuleBase::Array_Pool<double> dpsir_ylm_x(this->bxyz, LD_pool);
122129
ModuleBase::Array_Pool<double> dpsir_ylm_y(this->bxyz, LD_pool);
123130
ModuleBase::Array_Pool<double> dpsir_ylm_z(this->bxyz, LD_pool);
124131

125-
Gint_Tools::cal_dpsir_ylm(*this->gridt,
132+
Gint_Tools::cal_dpsir_ylm(*this->gridt,
126133
this->bxyz, na_grid, grid_index, delta_r,
127-
block_index.data(), block_size.data(),
134+
block_index.data(), block_size.data(),
128135
cal_flag.get_ptr_2D(),
129136
psir_ylm.get_ptr_2D(),
130137
dpsir_ylm_x.get_ptr_2D(),
@@ -146,7 +153,7 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
146153
LD_pool,
147154
grid_index, na_grid,
148155
block_index.data(), block_size.data(),
149-
cal_flag.get_ptr_2D(),
156+
cal_flag.get_ptr_2D(),
150157
dpsir_ylm_x.get_ptr_2D(),
151158
dpsix_DM.get_ptr_2D(),
152159
this->DMRGint[is],
@@ -166,13 +173,13 @@ void Gint::gint_kernel_tau(Gint_inout* inout) {
166173
LD_pool,
167174
grid_index, na_grid,
168175
block_index.data(), block_size.data(),
169-
cal_flag.get_ptr_2D(),
176+
cal_flag.get_ptr_2D(),
170177
dpsir_ylm_z.get_ptr_2D(),
171178
dpsiz_DM.get_ptr_2D(),
172179
this->DMRGint[is],
173180
true);
174181

175-
//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
182+
//do sum_i,mu g_i,mu(r) * d/dx_i psi_mu(r) to get kinetic energy density on grid
176183
if(inout->job==Gint_Tools::job_type::tau)
177184
{
178185
this->cal_meshball_tau(

source/module_hamilt_lcao/module_gint/gint_tools.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,19 @@ ModuleBase::Array_Pool<double> get_psir_vlbr3(
284284
const double* const* const psir_ylm); // psir_ylm[bxyz][LD_pool]
285285

286286
// sum_nu,R rho_mu,nu(R) psi_nu, for multi-k and gamma point
287-
void mult_psi_DMR(const Grid_Technique& gt,
288-
const int bxyz,
289-
const int LD_pool,
290-
const int& grid_index,
291-
const int& na_grid,
292-
const int* const block_index,
293-
const int* const block_size,
294-
bool** cal_flag,
295-
double** psi,
296-
double** psi_DMR,
297-
const hamilt::HContainer<double>* DM,
298-
const bool if_symm);
287+
void mult_psi_DMR(
288+
const Grid_Technique& gt,
289+
const int bxyz,
290+
const int LD_pool,
291+
const int &grid_index,
292+
const int &na_grid,
293+
const int*const block_index,
294+
const int*const block_size,
295+
const bool*const*const cal_flag,
296+
const double*const*const psi,
297+
double*const*const psi_DMR,
298+
const hamilt::HContainer<double>*const DM,
299+
const bool if_symm);
299300

300301

301302
// pair.first is the first index of the meshcell which is inside atoms ia1 and ia2.

0 commit comments

Comments
 (0)