Skip to content

Commit 9cc044e

Browse files
Refactor: remove some unused variables in module_gint (#5568)
* remove some unused variables * small change * remove some redundant lines * modify the class interface of grid_meshk * remove grid_index.h * modify some comment * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 2c2ed13 commit 9cc044e

File tree

15 files changed

+65
-154
lines changed

15 files changed

+65
-154
lines changed

source/module_base/vector3.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ template <class T> class Vector3
3636
Vector3(const Vector3<T> &v) : x(v.x), y(v.y), z(v.z){}; // Peize Lin add 2018-07-16
3737
explicit Vector3(const std::array<T,3> &v) :x(v[0]), y(v[1]), z(v[2]){}
3838

39+
template <typename U>
40+
explicit Vector3(const Vector3<U>& other) : x(static_cast<T>(other.x)), y(static_cast<T>(other.y)), z(static_cast<T>(other.z)) {}
41+
3942
Vector3(Vector3<T> &&v) noexcept : x(v.x), y(v.y), z(v.z) {}
4043

4144
/**

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,16 @@ class Gint {
133133

134134
//! calculate local potential contribution to the Hamiltonian
135135
//! na_grid: how many atoms on this (i,j,k) grid
136-
//! block_iw: dim is [na_grid], index of wave function for each block
137136
//! block_size: dim is [block_size], number of columns of a band
138137
//! block_index: dim is [na_grid+1], total number of atomic orbitals
139138
//! grid_index: index of grid group, for tracing iat
140139
//! cal_flag: dim is [bxyz][na_grid], whether the atom-grid distance is larger than cutoff
141140
//! psir_ylm: dim is [bxyz][LD_pool]
142141
//! psir_vlbr3: dim is [bxyz][LD_pool]
143142
//! hR: HContainer for storing the <phi_0|V|phi_R> matrix elements
144-
145143
void cal_meshball_vlocal(
146144
const int na_grid,
147145
const int LD_pool,
148-
const int* const block_iw,
149146
const int* const block_size,
150147
const int* const block_index,
151148
const int grid_index,
@@ -154,7 +151,6 @@ class Gint {
154151
const double* const* const psir_vlbr3,
155152
hamilt::HContainer<double>* hR);
156153

157-
158154
//! in gint_fvl.cpp
159155
//! calculate vl contributuion to force & stress via grid integrals
160156
void gint_kernel_force(const int na_grid,

source/module_hamilt_lcao/module_gint/gint_gamma_env.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ void Gint_Gamma::cal_env(const double* wfc, double* rho, UnitCell& ucell)
1919
}
2020
const int nbx = this->gridt->nbx;
2121
const int nby = this->gridt->nby;
22-
const int nbz_start = this->gridt->nbzp_start;
2322
const int nbz = this->gridt->nbzp;
2423
const int ncyz = this->ny * this->nplane; // mohan add 2012-03-25
2524
const int bxyz = this->bxyz;

source/module_hamilt_lcao/module_gint/gint_k_env.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "module_basis/module_ao/ORB_read.h"
77
#include "module_hamilt_pw/hamilt_pwdft/global.h"
88
#include "module_base/array_pool.h"
9+
#include "module_base/vector3.h"
910

1011
void Gint_k::cal_env_k(int ik,
1112
const std::complex<double>* psi_k,
@@ -26,7 +27,6 @@ void Gint_k::cal_env_k(int ik,
2627
}
2728
const int nbx = this->gridt->nbx;
2829
const int nby = this->gridt->nby;
29-
const int nbz_start = this->gridt->nbzp_start;
3030
const int nbz = this->gridt->nbzp;
3131
const int ncyz = this->ny * this->nplane; // mohan add 2012-03-25
3232

@@ -88,10 +88,7 @@ void Gint_k::cal_env_k(int ik,
8888

8989
// find R by which_unitcell and cal kphase
9090
const int id_ucell = this->gridt->which_unitcell[mcell_index1];
91-
const int Rx = this->gridt->ucell_index2x[id_ucell] + this->gridt->min_ucell_para[0];
92-
const int Ry = this->gridt->ucell_index2y[id_ucell] + this->gridt->min_ucell_para[1];
93-
const int Rz = this->gridt->ucell_index2z[id_ucell] + this->gridt->min_ucell_para[2];
94-
ModuleBase::Vector3<double> R((double)Rx, (double)Ry, (double)Rz);
91+
ModuleBase::Vector3<double> R(this->gridt->get_ucell_coords(id_ucell));
9592
// std::cout << "kvec_d: " << kvec_d[ik].x << " " << kvec_d[ik].y << " " << kvec_d[ik].z << std::endl;
9693
// std::cout << "kvec_c: " << kvec_c[ik].x << " " << kvec_c[ik].y << " " << kvec_c[ik].z << std::endl;
9794
// std::cout << "R: " << R.x << " " << R.y << " " << R.z << std::endl;

source/module_hamilt_lcao/module_gint/gint_vl.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "module_base/blas_connector.h"
99
#include "module_base/timer.h"
1010
#include "module_base/array_pool.h"
11+
#include "module_base/vector3.h"
1112
//#include <mkl_cblas.h>
1213

1314
#ifdef _OPENMP
@@ -22,7 +23,6 @@
2223
void Gint::cal_meshball_vlocal(
2324
const int na_grid, // how many atoms on this (i,j,k) grid
2425
const int LD_pool,
25-
const int*const block_iw, // block_iw[na_grid], index of wave functions for each block
2626
const int*const block_size, // block_size[na_grid], number of columns of a band
2727
const int*const block_index, // block_index[na_grid+1], count total number of atomis orbitals
2828
const int grid_index, // index of grid group, for tracing global atom index
@@ -41,18 +41,14 @@ void Gint::cal_meshball_vlocal(
4141
const int bcell1 = mcell_index + ia1;
4242
const int iat1 = this->gridt->which_atom[bcell1];
4343
const int id1 = this->gridt->which_unitcell[bcell1];
44-
const int r1x = this->gridt->ucell_index2x[id1];
45-
const int r1y = this->gridt->ucell_index2y[id1];
46-
const int r1z = this->gridt->ucell_index2z[id1];
44+
const ModuleBase::Vector3<int> r1 = this->gridt->get_ucell_coords(id1);
4745

4846
for(int ia2=0; ia2<na_grid; ++ia2)
4947
{
5048
const int bcell2 = mcell_index + ia2;
5149
const int iat2= this->gridt->which_atom[bcell2];
5250
const int id2 = this->gridt->which_unitcell[bcell2];
53-
const int r2x = this->gridt->ucell_index2x[id2];
54-
const int r2y = this->gridt->ucell_index2y[id2];
55-
const int r2z = this->gridt->ucell_index2z[id2];
51+
const ModuleBase::Vector3<int> r2 = this->gridt->get_ucell_coords(id2);
5652

5753
if(iat1<=iat2)
5854
{
@@ -77,12 +73,7 @@ void Gint::cal_meshball_vlocal(
7773
const int ib_length = last_ib-first_ib;
7874
if(ib_length<=0) { continue; }
7975

80-
// calculate the BaseMatrix of <iat1, iat2, R> atom-pair
81-
const int dRx = r1x - r2x;
82-
const int dRy = r1y - r2y;
83-
const int dRz = r1z - r2z;
84-
85-
const auto tmp_matrix = hR->find_matrix(iat1, iat2, dRx, dRy, dRz);
76+
const auto tmp_matrix = hR->find_matrix(iat1, iat2, r1-r2);
8677
if (tmp_matrix == nullptr)
8778
{
8879
continue;

source/module_hamilt_lcao/module_gint/gint_vl_cpu_interface.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void Gint::gint_kernel_vlocal(Gint_inout* inout) {
7272
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
7373
//and accumulates to the corresponding element in Hamiltonian
7474
this->cal_meshball_vlocal(
75-
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index,
75+
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index,
7676
cal_flag.get_ptr_2D(),psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(),
7777
&hRGint_thread);
7878
}
@@ -158,13 +158,13 @@ void Gint::gint_kernel_dvlocal(Gint_inout* inout) {
158158
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
159159
//and accumulates to the corresponding element in Hamiltonian
160160
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
161-
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
161+
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
162162
dpsir_ylm_x.get_ptr_2D(), &pvdpRx_thread);
163163
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
164-
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
164+
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
165165
dpsir_ylm_y.get_ptr_2D(), &pvdpRy_thread);
166166
this->cal_meshball_vlocal(na_grid, LD_pool, block_size.data(), block_index.data(),
167-
block_iw.data(), grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
167+
grid_index, cal_flag.get_ptr_2D(),psir_vlbr3.get_ptr_2D(),
168168
dpsir_ylm_z.get_ptr_2D(), &pvdpRz_thread);
169169
}
170170
#pragma omp critical(gint_k)
@@ -281,18 +281,18 @@ void Gint::gint_kernel_vlocal_meta(Gint_inout* inout) {
281281
//integrate (psi_mu*v(r)*dv) * psi_nu on grid
282282
//and accumulates to the corresponding element in Hamiltonian
283283
this->cal_meshball_vlocal(
284-
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
284+
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
285285
psir_ylm.get_ptr_2D(), psir_vlbr3.get_ptr_2D(), &hRGint_thread);
286286
//integrate (d/dx_i psi_mu*vk(r)*dv) * (d/dx_i psi_nu) on grid (x_i=x,y,z)
287287
//and accumulates to the corresponding element in Hamiltonian
288288
this->cal_meshball_vlocal(
289-
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
289+
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
290290
dpsir_ylm_x.get_ptr_2D(), dpsix_vlbr3.get_ptr_2D(), &hRGint_thread);
291291
this->cal_meshball_vlocal(
292-
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
292+
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
293293
dpsir_ylm_y.get_ptr_2D(), dpsiy_vlbr3.get_ptr_2D(), &hRGint_thread);
294294
this->cal_meshball_vlocal(
295-
na_grid, LD_pool, block_iw.data(), block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
295+
na_grid, LD_pool, block_size.data(), block_index.data(), grid_index, cal_flag.get_ptr_2D(),
296296
dpsir_ylm_z.get_ptr_2D(), dpsiz_vlbr3.get_ptr_2D(), &hRGint_thread);
297297
}
298298

source/module_hamilt_lcao/module_gint/grid_index.h

Lines changed: 0 additions & 20 deletions
This file was deleted.

source/module_hamilt_lcao/module_gint/grid_meshcell.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class Grid_MeshCell: public Grid_MeshK
1818
int nbzp_start,nbzp;
1919
// save the position of each meshcell.
2020
std::vector<std::vector<double>> meshcell_pos;
21+
22+
private:
23+
// latvec0 and GT are not used in current code.
24+
// these two variables may be removed in the future.
2125
ModuleBase::Matrix3 meshcell_latvec0;
2226
ModuleBase::Matrix3 meshcell_GT;
2327

@@ -45,7 +49,7 @@ class Grid_MeshCell: public Grid_MeshK
4549
const int &nbzp_in);
4650

4751
void init_latvec(const UnitCell &ucell);
48-
void init_meshcell_pos(void);
52+
void init_meshcell_pos();
4953

5054
};
5155

source/module_hamilt_lcao/module_gint/grid_meshk.cpp

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,14 @@ int Grid_MeshK::cal_Rindex(const int &u1, const int &u2, const int &u3)const
3131
return (x3 + x2 * this->nu3 + x1 * this->nu2 * this->nu3);
3232
}
3333

34-
void Grid_MeshK::init_ucell_para(void)
34+
ModuleBase::Vector3<int> Grid_MeshK::get_ucell_coords(const int &Rindex)const
3535
{
36-
this->max_ucell_para=std::vector<int>(3,0);
37-
this->max_ucell_para[0]=this->maxu1;
38-
this->max_ucell_para[1]=this->maxu2;
39-
this->max_ucell_para[2]=this->maxu3;
40-
41-
this->min_ucell_para=std::vector<int>(3,0);
42-
this->min_ucell_para[0]=this->minu1;
43-
this->min_ucell_para[1]=this->minu2;
44-
this->min_ucell_para[2]=this->minu3;
45-
46-
this->num_ucell_para=std::vector<int>(4,0);
47-
this->num_ucell_para[0]=this->nu1;
48-
this->num_ucell_para[1]=this->nu2;
49-
this->num_ucell_para[2]=this->nu3;
50-
this->num_ucell_para[3]=this->nutot;
51-
}
36+
const int x = ucell_index2x[Rindex];
37+
const int y = ucell_index2y[Rindex];
38+
const int z = ucell_index2z[Rindex];
5239

40+
return ModuleBase::Vector3<int>(x, y, z);
41+
}
5342

5443
void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dze,const int& nbx, const int& nby, const int& nbz)
5544
{
@@ -66,8 +55,10 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
6655
this->minu2 = (-dye+1) / nby - 1;
6756
this->minu3 = (-dze+1) / nbz - 1;
6857

69-
if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MaxUnitcell",maxu1,maxu2,maxu3);
70-
if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MinUnitcell",minu1,minu2,minu3);
58+
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MaxUnitcell",maxu1,maxu2,maxu3);
59+
}
60+
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"MinUnitcell",minu1,minu2,minu3);
61+
}
7162

7263
//--------------------------------------
7364
// number of unitcell in each direction.
@@ -77,9 +68,10 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
7768
this->nu3 = maxu3 - minu3 + 1;
7869
this->nutot = nu1 * nu2 * nu3;
7970

80-
init_ucell_para();
81-
if(PARAM.inp.test_gridt)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellNumber",nu1,nu2,nu3);
82-
if(PARAM.inp.out_level != "m") ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellTotal",nutot);
71+
if(PARAM.inp.test_gridt) {ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellNumber",nu1,nu2,nu3);
72+
}
73+
if(PARAM.inp.out_level != "m") { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"UnitCellTotal",nutot);
74+
}
8375

8476

8577
this->ucell_index2x = std::vector<int>(nutot, 0);
@@ -97,9 +89,9 @@ void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dz
9789
const int cell = cal_Rindex(i,j,k);
9890
assert(cell<nutot);
9991

100-
this->ucell_index2x[cell] = i-minu1;
101-
this->ucell_index2y[cell] = j-minu2;
102-
this->ucell_index2z[cell] = k-minu3;
92+
this->ucell_index2x[cell] = i;
93+
this->ucell_index2y[cell] = j;
94+
this->ucell_index2z[cell] = k;
10395

10496
}
10597
}

source/module_hamilt_lcao/module_gint/grid_meshk.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,23 @@
22
#define GRID_MESHK_H
33
#include "module_base/global_function.h"
44
#include "module_base/global_variable.h"
5+
#include "module_base/vector3.h"
56

67
class Grid_MeshK
78
{
89
public:
910
Grid_MeshK();
1011
~Grid_MeshK();
11-
// from 1D index to unitcell.
12-
std::vector<int> ucell_index2x;
13-
std::vector<int> ucell_index2y;
14-
std::vector<int> ucell_index2z;
15-
16-
// the unitcell parameters.
17-
std::vector<int> max_ucell_para;
18-
std::vector<int> min_ucell_para;
19-
std::vector<int> num_ucell_para;
2012

2113
// calculate the index of unitcell.
2214
int cal_Rindex(const int& u1, const int& u2, const int& u3)const;
2315

16+
ModuleBase::Vector3<int> get_ucell_coords(const int& Rindex)const;
17+
2418
/// move operator for the next ESolver to directly use its infomation
2519
Grid_MeshK& operator=(Grid_MeshK&& rhs) = default;
2620

27-
protected:
21+
private:
2822
// the max and the min unitcell.
2923
int maxu1;
3024
int maxu2;
@@ -40,11 +34,15 @@ class Grid_MeshK
4034
int nu3;
4135
int nutot;
4236

37+
// from 1D index to unitcell.
38+
std::vector<int> ucell_index2x;
39+
std::vector<int> ucell_index2y;
40+
std::vector<int> ucell_index2z;
41+
42+
protected:
4343
// calculate the extended unitcell.
4444
void cal_extended_cell(const int &dxe, const int &dye, const int &dze,
4545
const int& nbx, const int& nby, const int& nbz);
46-
// initialize the unitcell parameters.
47-
void init_ucell_para(void);
4846
};
4947

5048
#endif

0 commit comments

Comments
 (0)