Skip to content

Commit 32dd06d

Browse files
committed
modify the class interface of grid_meshk
1 parent 133629f commit 32dd06d

File tree

9 files changed

+44
-55
lines changed

9 files changed

+44
-55
lines changed

source/module_base/vector3.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ template <class T> class Vector3
3636
Vector3(const Vector3<T> &v) : x(v.x), y(v.y), z(v.z){}; // Peize Lin add 2018-07-16
3737
explicit Vector3(const std::array<T,3> &v) :x(v[0]), y(v[1]), z(v[2]){}
3838

39+
template <typename U>
40+
explicit Vector3(const Vector3<U>& other) : x(static_cast<T>(other.x)), y(static_cast<T>(other.y)), z(static_cast<T>(other.z)) {}
41+
3942
Vector3(Vector3<T> &&v) noexcept : x(v.x), y(v.y), z(v.z) {}
4043

4144
/**

source/module_hamilt_lcao/module_gint/gint_k_env.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "module_basis/module_ao/ORB_read.h"
77
#include "module_hamilt_pw/hamilt_pwdft/global.h"
88
#include "module_base/array_pool.h"
9+
#include "module_base/vector3.h"
910

1011
void Gint_k::cal_env_k(int ik,
1112
const std::complex<double>* psi_k,
@@ -87,10 +88,7 @@ void Gint_k::cal_env_k(int ik,
8788

8889
// find R by which_unitcell and cal kphase
8990
const int id_ucell = this->gridt->which_unitcell[mcell_index1];
90-
const int Rx = this->gridt->ucell_index2x[id_ucell];
91-
const int Ry = this->gridt->ucell_index2y[id_ucell];
92-
const int Rz = this->gridt->ucell_index2z[id_ucell];
93-
ModuleBase::Vector3<double> R((double)Rx, (double)Ry, (double)Rz);
91+
ModuleBase::Vector3<double> R(this->gridt->get_ucell_coords(id_ucell));
9492
// std::cout << "kvec_d: " << kvec_d[ik].x << " " << kvec_d[ik].y << " " << kvec_d[ik].z << std::endl;
9593
// std::cout << "kvec_c: " << kvec_c[ik].x << " " << kvec_c[ik].y << " " << kvec_c[ik].z << std::endl;
9694
// std::cout << "R: " << R.x << " " << R.y << " " << R.z << std::endl;

source/module_hamilt_lcao/module_gint/gint_vl.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "module_base/blas_connector.h"
99
#include "module_base/timer.h"
1010
#include "module_base/array_pool.h"
11+
#include "module_base/vector3.h"
1112
//#include <mkl_cblas.h>
1213

1314
#ifdef _OPENMP
@@ -40,18 +41,14 @@ void Gint::cal_meshball_vlocal(
4041
const int bcell1 = mcell_index + ia1;
4142
const int iat1 = this->gridt->which_atom[bcell1];
4243
const int id1 = this->gridt->which_unitcell[bcell1];
43-
const int r1x = this->gridt->ucell_index2x[id1];
44-
const int r1y = this->gridt->ucell_index2y[id1];
45-
const int r1z = this->gridt->ucell_index2z[id1];
44+
const ModuleBase::Vector3<int> r1 = this->gridt->get_ucell_coords(id1);
4645

4746
for(int ia2=0; ia2<na_grid; ++ia2)
4847
{
4948
const int bcell2 = mcell_index + ia2;
5049
const int iat2= this->gridt->which_atom[bcell2];
5150
const int id2 = this->gridt->which_unitcell[bcell2];
52-
const int r2x = this->gridt->ucell_index2x[id2];
53-
const int r2y = this->gridt->ucell_index2y[id2];
54-
const int r2z = this->gridt->ucell_index2z[id2];
51+
const ModuleBase::Vector3<int> r2 = this->gridt->get_ucell_coords(id2);
5552

5653
if(iat1<=iat2)
5754
{
@@ -76,12 +73,7 @@ void Gint::cal_meshball_vlocal(
7673
const int ib_length = last_ib-first_ib;
7774
if(ib_length<=0) { continue; }
7875

79-
// calculate the BaseMatrix of <iat1, iat2, R> atom-pair
80-
const int dRx = r1x - r2x;
81-
const int dRy = r1y - r2y;
82-
const int dRz = r1z - r2z;
83-
84-
const auto tmp_matrix = hR->find_matrix(iat1, iat2, dRx, dRy, dRz);
76+
const auto tmp_matrix = hR->find_matrix(iat1, iat2, r1-r2);
8577
if (tmp_matrix == nullptr)
8678
{
8779
continue;

source/module_hamilt_lcao/module_gint/grid_meshk.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ int Grid_MeshK::cal_Rindex(const int &u1, const int &u2, const int &u3)const
3131
return (x3 + x2 * this->nu3 + x1 * this->nu2 * this->nu3);
3232
}
3333

34+
ModuleBase::Vector3<int> Grid_MeshK::get_ucell_coords(const int &Rindex)const
35+
{
36+
const int x = ucell_index2x[Rindex];
37+
const int y = ucell_index2y[Rindex];
38+
const int z = ucell_index2z[Rindex];
39+
40+
return ModuleBase::Vector3<int>(x, y, z);
41+
}
42+
3443
void Grid_MeshK::cal_extended_cell(const int &dxe, const int &dye, const int &dze,const int& nbx, const int& nby, const int& nbz)
3544
{
3645
ModuleBase::TITLE("Grid_MeshK","cal_extended_cell");

source/module_hamilt_lcao/module_gint/grid_meshk.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,19 @@
22
#define GRID_MESHK_H
33
#include "module_base/global_function.h"
44
#include "module_base/global_variable.h"
5+
#include "module_base/vector3.h"
56

67
class Grid_MeshK
78
{
89
public:
910
Grid_MeshK();
1011
~Grid_MeshK();
11-
// from 1D index to unitcell.
12-
std::vector<int> ucell_index2x;
13-
std::vector<int> ucell_index2y;
14-
std::vector<int> ucell_index2z;
1512

1613
// calculate the index of unitcell.
1714
int cal_Rindex(const int& u1, const int& u2, const int& u3)const;
1815

16+
ModuleBase::Vector3<int> get_ucell_coords(const int& Rindex)const;
17+
1918
/// move operator for the next ESolver to directly use its infomation
2019
Grid_MeshK& operator=(Grid_MeshK&& rhs) = default;
2120

@@ -35,6 +34,11 @@ class Grid_MeshK
3534
int nu3;
3635
int nutot;
3736

37+
// from 1D index to unitcell.
38+
std::vector<int> ucell_index2x;
39+
std::vector<int> ucell_index2y;
40+
std::vector<int> ucell_index2z;
41+
3842
protected:
3943
// calculate the extended unitcell.
4044
void cal_extended_cell(const int &dxe, const int &dye, const int &dze,

source/module_hamilt_lcao/module_gint/gtask_force.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gint_force_gpu.h"
44
#include "module_base/ylm.h"
55
#include "module_hamilt_lcao/module_gint/gint_tools.h"
6+
#include "module_base/vector3.h"
67
namespace GintKernel
78
{
89

@@ -102,9 +103,7 @@ void alloc_mult_force(const hamilt::HContainer<double>* dm,
102103
const int mcell_index1 = bcell_start_index + atom1;
103104
const int iat1 = gridt.which_atom[mcell_index1];
104105
const int uc1 = gridt.which_unitcell[mcell_index1];
105-
const int rx1 = gridt.ucell_index2x[uc1];
106-
const int ry1 = gridt.ucell_index2y[uc1];
107-
const int rz1 = gridt.ucell_index2z[uc1];
106+
const ModuleBase::Vector3<int> r1 = gridt.get_ucell_coords(uc1);
108107
const int it1 = ucell.iat2it[iat1];
109108
const int nw1 = ucell.atoms[it1].nw;
110109

@@ -113,10 +112,8 @@ void alloc_mult_force(const hamilt::HContainer<double>* dm,
113112
const int mcell_index2 = bcell_start_index + atom2;
114113
const int iat2 = gridt.which_atom[mcell_index2];
115114
const int uc2 = gridt.which_unitcell[mcell_index2];
116-
const int rx2 = gridt.ucell_index2x[uc2];
117-
const int ry2 = gridt.ucell_index2y[uc2];
118-
const int rz2 = gridt.ucell_index2z[uc2];
119-
const int offset = dm->find_matrix_offset(iat1, iat2, rx1-rx2, ry1-ry2, rz1-rz2);
115+
const ModuleBase::Vector3<int> r2 = gridt.get_ucell_coords(uc2);
116+
const int offset = dm->find_matrix_offset(iat1, iat2, r1-r2);
120117
if (offset == -1)
121118
{
122119
continue;

source/module_hamilt_lcao/module_gint/gtask_rho.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "gint_rho_gpu.h"
22
#include "module_base/ylm.h"
33
#include "module_hamilt_lcao/module_gint/gint_tools.h"
4+
#include "module_base/vector3.h"
45
#include "omp.h"
56
namespace GintKernel
67
{
@@ -85,9 +86,7 @@ void alloc_mult_dot_rho(const hamilt::HContainer<double>* dm,
8586
const int mcell_index1 = bcell_start_index + atom1;
8687
const int iat1 = gridt.which_atom[mcell_index1];
8788
const int uc1 = gridt.which_unitcell[mcell_index1];
88-
const int rx1 = gridt.ucell_index2x[uc1];
89-
const int ry1 = gridt.ucell_index2y[uc1];
90-
const int rz1 = gridt.ucell_index2z[uc1];
89+
const ModuleBase::Vector3<int> r1 = gridt.get_ucell_coords(uc1);
9190
const int it1 = ucell.iat2it[iat1];
9291
const int nw1 = ucell.atoms[it1].nw;
9392

@@ -97,10 +96,8 @@ void alloc_mult_dot_rho(const hamilt::HContainer<double>* dm,
9796
const int mcell_index2 = bcell_start_index + atom2;
9897
const int iat2 = gridt.which_atom[mcell_index2];
9998
const int uc2 = gridt.which_unitcell[mcell_index2];
100-
const int rx2 = gridt.ucell_index2x[uc2];
101-
const int ry2 = gridt.ucell_index2y[uc2];
102-
const int rz2 = gridt.ucell_index2z[uc2];
103-
const int offset = dm->find_matrix_offset(iat1, iat2, rx1-rx2, ry1-ry2, rz1-rz2);
99+
const ModuleBase::Vector3<int> r2 = gridt.get_ucell_coords(uc2);
100+
const int offset = dm->find_matrix_offset(iat1, iat2, r1-r2);
104101
if (offset == -1)
105102
{
106103
continue;

source/module_hamilt_lcao/module_gint/gtask_vl.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gint_vl_gpu.h"
44
#include "module_base/ylm.h"
55
#include "module_hamilt_lcao/module_gint/gint_tools.h"
6+
#include "module_base/vector3.h"
67
namespace GintKernel
78
{
89

@@ -97,19 +98,15 @@ void alloc_mult_vlocal(const hamilt::HContainer<double>* hRGint,
9798
{
9899
const int iat1 = gridt.which_atom[bcell_start_index + atom1];
99100
const int uc1 = gridt.which_unitcell[bcell_start_index + atom1];
100-
const int rx1 = gridt.ucell_index2x[uc1];
101-
const int ry1 = gridt.ucell_index2y[uc1];
102-
const int rz1 = gridt.ucell_index2z[uc1];
101+
const ModuleBase::Vector3<int> r1 = gridt.get_ucell_coords(uc1);
103102
const int it1 = ucell.iat2it[iat1];
104103

105104
for (int atom2 = 0; atom2 < atom_num; atom2++)
106105
{
107106
const int iat2 = gridt.which_atom[bcell_start_index + atom2];
108107
const int uc2 = gridt.which_unitcell[bcell_start_index + atom2];
109-
const int rx2 = gridt.ucell_index2x[uc2];
110-
const int ry2 = gridt.ucell_index2y[uc2];
111-
const int rz2 = gridt.ucell_index2z[uc2];
112-
int offset = hRGint->find_matrix_offset(iat1, iat2, rx1-rx2, ry1-ry2, rz1-rz2);
108+
const ModuleBase::Vector3<int> r2 = gridt.get_ucell_coords(uc2);
109+
int offset = hRGint->find_matrix_offset(iat1, iat2, r1-r2);
113110
if (offset == -1)
114111
{
115112
continue;

source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ void mult_psi_DMR(
3535

3636
//! get cell R1, this step is redundant in gamma_only case.
3737
const int id1 = gt.which_unitcell[bcell1];
38-
const int R1x = gt.ucell_index2x[id1];
39-
const int R1y = gt.ucell_index2y[id1];
40-
const int R1z = gt.ucell_index2z[id1];
38+
const ModuleBase::Vector3<int> r1 = gt.get_ucell_coords(id1);
4139

4240
//! density
4341
if (if_symm)
@@ -74,20 +72,14 @@ void mult_psi_DMR(
7472
const int bcell2 = gt.bcell_start[grid_index] + ia2;
7573
const int iat2 = gt.which_atom[bcell2];
7674
const int id2 = gt.which_unitcell[bcell2];
77-
78-
//! get cell R2, this step is redundant in gamma_only case.
79-
const int R2x = gt.ucell_index2x[id2];
80-
const int R2y = gt.ucell_index2y[id2];
81-
const int R2z = gt.ucell_index2z[id2];
8275

83-
//! calculate the 'offset': R2 position relative
84-
//! to R1 atom, this step is redundant in gamma_only case.
85-
const int dRx = R1x - R2x;
86-
const int dRy = R1y - R2y;
87-
const int dRz = R1z - R2z;
76+
//---------------
77+
// get cell R2, this step is redundant in gamma_only case.
78+
//---------------
79+
const ModuleBase::Vector3<int> r2 = gt.get_ucell_coords(id2);
8880

8981
// get AtomPair
90-
const auto tmp_matrix = DM->find_matrix(iat1, iat2, dRx, dRy, dRz);
82+
const auto tmp_matrix = DM->find_matrix(iat1, iat2, r1-r2);
9183
if (tmp_matrix == nullptr)
9284
{
9385
continue;

0 commit comments

Comments
 (0)