Skip to content

Commit 368c0fe

Browse files
committed
replace pvpr in gint_k with hcontainer
1 parent dcff74d commit 368c0fe

File tree

15 files changed

+247
-295
lines changed

15 files changed

+247
-295
lines changed

source/module_esolver/set_matrix_grid.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void ESolver_KS_LCAO<TK, TR>::set_matrix_grid(Record_adj& ra)
6868
this->pw_rho->nplane,
6969
this->pw_rho->startz_current,
7070
GlobalC::ucell,
71+
GlobalC::GridD,
7172
dr_uniform,
7273
rcuts,
7374
psi_u,

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 18 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,13 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
157157
if (this->DMRGint[is] != nullptr) {
158158
delete this->DMRGint[is];
159159
}
160+
if (this->hRGint_tmp[is] != nullptr) {
161+
delete this->hRGint_tmp[is];
162+
}
160163
this->DMRGint[is] = new hamilt::HContainer<double>(ucell_in.nat);
161164
}
165+
this->hRGint_tmp[0]
166+
= new hamilt::HContainer<double>(ucell_in.nat);
162167
#ifdef __MPI
163168
if (this->DMRGint_full != nullptr) {
164169
delete this->DMRGint_full;
@@ -167,132 +172,14 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
167172
#endif
168173
}
169174

170-
// prepare the row_index and col_index for construct AtomPairs, they are
171-
// same, name as orb_index
172-
std::vector<int> orb_index(ucell_in.nat + 1);
173-
orb_index[0] = 0;
174-
for (int i = 1; i < orb_index.size(); i++) {
175-
int type = ucell_in.iat2it[i - 1];
176-
orb_index[i] = orb_index[i - 1] + ucell_in.atoms[type].nw;
177-
}
178-
std::vector<int> orb_index_npol;
179-
if (npol == 2) {
180-
orb_index_npol.resize(ucell_in.nat + 1);
181-
orb_index_npol[0] = 0;
182-
for (int i = 1; i < orb_index_npol.size(); i++) {
183-
int type = ucell_in.iat2it[i - 1];
184-
orb_index_npol[i]
185-
= orb_index_npol[i - 1] + ucell_in.atoms[type].nw * npol;
186-
}
187-
}
188-
189175
if (PARAM.globalv.gamma_only_local && nspin != 4) {
190176
this->hRGint->fix_gamma();
191177
}
192-
for (int T1 = 0; T1 < ucell_in.ntype; ++T1) {
193-
const Atom* atom1 = &(ucell_in.atoms[T1]);
194-
for (int I1 = 0; I1 < atom1->na; ++I1) {
195-
auto& tau1 = atom1->tau[I1];
196-
197-
gd->Find_atom(ucell_in, tau1, T1, I1);
198-
199-
const int iat1 = ucell_in.itia2iat(T1, I1);
200-
201-
// for grid integration (on FFT box),
202-
// we only need to consider <phi_i | phi_j>,
203-
204-
// whether this atom is in this processor.
205-
if (this->gridt->in_this_processor[iat1]) {
206-
for (int ad = 0; ad < gd->getAdjacentNum() + 1; ++ad) {
207-
const int T2 = gd->getType(ad);
208-
const int I2 = gd->getNatom(ad);
209-
const int iat2 = ucell_in.itia2iat(T2, I2);
210-
const Atom* atom2 = &(ucell_in.atoms[T2]);
211-
212-
// NOTE: hRGint wil save total number of atom pairs,
213-
// if only upper triangle is saved, the lower triangle will
214-
// be lost in 2D-block parallelization. if the adjacent atom
215-
// is in this processor.
216-
if (this->gridt->in_this_processor[iat2]) {
217-
ModuleBase::Vector3<double> dtau
218-
= gd->getAdjacentTau(ad) - tau1;
219-
double distance = dtau.norm() * ucell_in.lat0;
220-
double rcut
221-
= this->gridt->rcuts[T1] + this->gridt->rcuts[T2];
222-
223-
// if(distance < rcut)
224-
// mohan reset this 2013-07-02 in Princeton
225-
// we should make absolutely sure that the distance is
226-
// smaller than rcuts[it] this should be consistant
227-
// with LCAO_nnr::cal_nnrg function typical example : 7
228-
// Bohr cutoff Si orbital in 14 Bohr length of cell.
229-
// distance = 7.0000000000000000
230-
// rcuts[it] = 7.0000000000000008
231-
if (distance < rcut - 1.0e-15) {
232-
// calculate R index
233-
auto& R_index = gd->getBox(ad);
234-
// insert this atom-pair into this->hRGint
235-
if (npol == 1) {
236-
hamilt::AtomPair<double> tmp_atom_pair(
237-
iat1,
238-
iat2,
239-
R_index.x,
240-
R_index.y,
241-
R_index.z,
242-
orb_index.data(),
243-
orb_index.data(),
244-
ucell_in.nat);
245-
this->hRGint->insert_pair(tmp_atom_pair);
246-
} else {
247-
// HR is complex and size is nw * npol
248-
hamilt::AtomPair<std::complex<double>>
249-
tmp_atom_pair(iat1,
250-
iat2,
251-
R_index.x,
252-
R_index.y,
253-
R_index.z,
254-
orb_index_npol.data(),
255-
orb_index_npol.data(),
256-
ucell_in.nat);
257-
this->hRGintCd->insert_pair(tmp_atom_pair);
258-
// DMR is double now and size is nw
259-
hamilt::AtomPair<double> tmp_dmR(
260-
iat1,
261-
iat2,
262-
R_index.x,
263-
R_index.y,
264-
R_index.z,
265-
orb_index.data(),
266-
orb_index.data(),
267-
ucell_in.nat);
268-
for (int is = 0; is < this->DMRGint.size();
269-
is++) {
270-
this->DMRGint[is]->insert_pair(tmp_dmR);
271-
}
272-
#ifdef __MPI
273-
hamilt::AtomPair<double> tmp_dmR_full(
274-
iat1,
275-
iat2,
276-
R_index.x,
277-
R_index.y,
278-
R_index.z,
279-
orb_index_npol.data(),
280-
orb_index_npol.data(),
281-
ucell_in.nat);
282-
// tmp DMR for transfer
283-
this->DMRGint_full->insert_pair(tmp_dmR_full);
284-
#endif
285-
}
286-
}
287-
} // end iat2
288-
} // end ad
289-
} // end iat
290-
} // end I1
291-
} // end T1
292178
if (npol == 1) {
179+
this->hRGint->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
293180
this->hRGint->allocate(nullptr, false);
294181
ModuleBase::Memory::record("Gint::hRGint",
295-
this->hRGint->get_memory_size());
182+
this->hRGint->get_memory_size());
296183
// initialize DMRGint with hRGint when NSPIN != 4
297184
for (int is = 0; is < this->DMRGint.size(); is++) {
298185
if (this->DMRGint[is] != nullptr) {
@@ -304,16 +191,20 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, Grid_Driver* gd, const int&
304191
this->DMRGint[0]->get_memory_size()
305192
* this->DMRGint.size());
306193
} else {
307-
this->hRGintCd->allocate(nullptr, 0);
308-
ModuleBase::Memory::record("Gint::hRGintCd",
309-
this->hRGintCd->get_memory_size());
310-
for (int is = 0; is < this->DMRGint.size(); is++) {
194+
this->hRGintCd->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
195+
for(int is = 0; is < nspin; is++) {
196+
this->hRGint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
197+
this->DMRGint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
198+
this->hRGint_tmp[is]->allocate(nullptr, false);
311199
this->DMRGint[is]->allocate(nullptr, false);
312200
}
201+
ModuleBase::Memory::record("Gint::hRGint_tmp",
202+
this->hRGint_tmp[0]->get_memory_size()*nspin);
313203
ModuleBase::Memory::record("Gint::DMRGint",
314-
this->DMRGint[0]->get_memory_size()
315-
* this->DMRGint.size());
204+
this->DMRGint[0]->get_memory_size()
205+
* this->DMRGint.size()*nspin);
316206
#ifdef __MPI
207+
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
317208
this->DMRGint_full->allocate(nullptr, false);
318209
ModuleBase::Memory::record("Gint::DMRGint_full",
319210
this->DMRGint_full->get_memory_size());
@@ -397,4 +288,4 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
397288
}
398289
}
399290
ModuleBase::timer::tick("Gint", "transfer_DMR");
400-
}
291+
}

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class Gint {
152152
const UnitCell& ucell,
153153
hamilt::HContainer<double>* hR = nullptr);
154154

155-
void cal_meshball_vlocal_gamma(
155+
void cal_meshball_vlocal(
156156
const int na_grid, // how many atoms on this (i,j,k) grid
157157
const int LD_pool,
158158
const int* const block_iw, // block_iw[na_grid], index of wave
@@ -275,6 +275,7 @@ class Gint {
275275
= nullptr; // stores Hamiltonian in reduced format, for multi-l
276276
hamilt::HContainer<double>* hRGint
277277
= nullptr; // stores Hamiltonian in sparse format
278+
std::vector<hamilt::HContainer<double>*> hRGint_tmp; // size of vec is 4, only used when nspin = 4
278279
hamilt::HContainer<std::complex<double>>* hRGintCd
279280
= nullptr; // stores Hamiltonian in sparse format
280281
std::vector<hamilt::HContainer<double>*>

source/module_hamilt_lcao/module_gint/gint_k.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class Gint_k : public Gint {
4242
} else if (this->spin_now != -1) {
4343
int start_spin = -1;
4444
this->reset_spin(start_spin);
45-
this->destroy_pvpR();
46-
this->allocate_pvpR();
45+
// this->destroy_pvpR();
46+
// this->allocate_pvpR();
4747
}
4848
return;
4949
}
@@ -63,8 +63,7 @@ class Gint_k : public Gint {
6363
* then pass this->hRGint to Veff<OperatorLCAO>::hR
6464
*/
6565
void transfer_pvpR(hamilt::HContainer<double>* hR,
66-
const UnitCell* ucell_in,
67-
Grid_Driver* gd);
66+
const UnitCell* ucell_in, Grid_Driver* gd);
6867
void transfer_pvpR(hamilt::HContainer<std::complex<double>>* hR,
6968
const UnitCell* ucell_in,
7069
Grid_Driver* gd);

source/module_hamilt_lcao/module_gint/gint_k_pvpr.cpp

Lines changed: 26 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
#include "module_basis/module_ao/ORB_read.h"
1313
#include "module_cell/module_neighbor/sltk_grid_driver.h"
1414
#include "module_hamilt_pw/hamilt_pwdft/global.h"
15+
#include <mpi.h>
1516

1617
void Gint_k::allocate_pvpR(void)
1718
{
1819
ModuleBase::TITLE("Gint_k", "allocate_pvpR");
19-
2020
if (this->pvpR_alloc_flag)
2121
{
2222
return; // Liuxh add, 20181012
@@ -41,7 +41,6 @@ void Gint_k::allocate_pvpR(void)
4141
void Gint_k::destroy_pvpR(void)
4242
{
4343
ModuleBase::TITLE("Gint_k", "destroy_pvpR");
44-
4544
if (!pvpR_alloc_flag)
4645
{
4746
return;
@@ -60,113 +59,42 @@ void Gint_k::destroy_pvpR(void)
6059
#include "module_hamilt_lcao/module_hcontainer/hcontainer_funcs.h"
6160

6261
// transfer_pvpR, NSPIN = 1 or 2
63-
void Gint_k::transfer_pvpR(hamilt::HContainer<double>* hR, const UnitCell* ucell_in, Grid_Driver* gd)
62+
void Gint_k::transfer_pvpR(hamilt::HContainer<double>* hR, const UnitCell* ucell, Grid_Driver* gd)
6463
{
6564
ModuleBase::TITLE("Gint_k", "transfer_pvpR");
6665
ModuleBase::timer::tick("Gint_k", "transfer_pvpR");
6766

68-
if (!pvpR_alloc_flag || this->hRGint == nullptr)
69-
{
70-
ModuleBase::WARNING_QUIT("Gint_k::destroy_pvpR", "pvpR hasnot been allocated yet!");
71-
}
72-
this->hRGint->set_zero();
73-
74-
const int npol = PARAM.globalv.npol;
75-
const UnitCell& ucell = *ucell_in;
76-
for (int iat = 0; iat < ucell.nat; ++iat)
67+
for (int iap = 0; iap < this->hRGint->size_atom_pairs(); iap++)
7768
{
78-
const int T1 = ucell.iat2it[iat];
79-
const int I1 = ucell.iat2ia[iat];
69+
auto& ap = this->hRGint->get_atom_pair(iap);
70+
const int iat1 = ap.get_atom_i();
71+
const int iat2 = ap.get_atom_j();
72+
if (iat1 > iat2)
8073
{
81-
// atom in this grid piece.
82-
if (this->gridt->in_this_processor[iat])
83-
{
84-
Atom* atom1 = &ucell.atoms[T1];
85-
86-
// get the start positions of elements.
87-
const int DM_start = this->gridt->nlocstartg[iat];
88-
89-
// get the coordinates of adjacent atoms.
90-
auto& tau1 = ucell.atoms[T1].tau[I1];
91-
// gd.Find_atom(tau1);
92-
AdjacentAtomInfo adjs;
93-
gd->Find_atom(ucell, tau1, T1, I1, &adjs);
94-
// search for the adjacent atoms.
95-
int nad = 0;
96-
97-
for (int ad = 0; ad < adjs.adj_num + 1; ad++)
98-
{
99-
// get iat2
100-
const int T2 = adjs.ntype[ad];
101-
const int I2 = adjs.natom[ad];
102-
const int iat2 = ucell.itia2iat(T2, I2);
103-
104-
// adjacent atom is also on the grid.
105-
if (this->gridt->in_this_processor[iat2])
106-
{
107-
Atom* atom2 = &ucell.atoms[T2];
108-
auto dtau = adjs.adjacent_tau[ad] - tau1;
109-
double distance = dtau.norm() * ucell.lat0;
110-
double rcut = this->gridt->rcuts[T1] + this->gridt->rcuts[T2];
111-
112-
if (distance < rcut)
113-
{
114-
if (iat > iat2)
115-
{ // skip the lower triangle.
116-
nad++;
117-
continue;
118-
}
119-
// calculate the distance between iat1 and iat2.
120-
// ModuleBase::Vector3<double> dR = gd.getAdjacentTau(ad) - tau1;
121-
auto& dR = adjs.box[ad];
122-
// dR.x = adjs.box[ad].x;
123-
// dR.y = adjs.box[ad].y;
124-
// dR.z = adjs.box[ad].z;
125-
126-
int ixxx = DM_start + this->gridt->find_R2st[iat][nad];
127-
128-
hamilt::BaseMatrix<double>* tmp_matrix = this->hRGint->find_matrix(iat, iat2, dR);
74+
// fill lower triangle matrix with upper triangle matrix
75+
// the upper <IJR> is <iat2, iat1>
76+
const hamilt::AtomPair<double>* upper_ap = this->hRGint->find_pair(iat2, iat1);
77+
const hamilt::AtomPair<double>* lower_ap = this->hRGint->find_pair(iat1, iat2);
12978
#ifdef __DEBUG
130-
assert(tmp_matrix != nullptr);
79+
assert(upper_ap != nullptr);
13180
#endif
132-
double* tmp_pointer = tmp_matrix->get_pointer();
133-
const double* vijR = &pvpR_reduced[0][ixxx];
134-
for (int iw = 0; iw < atom1->nw; iw++)
135-
{
136-
for (int iw2 = 0; iw2 < atom2->nw; ++iw2)
137-
{
138-
*tmp_pointer++ = *vijR++;
139-
}
140-
}
141-
// save the lower triangle.
142-
if (iat < iat2) // skip iat == iat2
143-
{
144-
hamilt::BaseMatrix<double>* conj_matrix = this->hRGint->find_matrix(iat2, iat, -dR);
145-
#ifdef __DEBUG
146-
assert(conj_matrix != nullptr);
147-
#endif
148-
tmp_pointer = tmp_matrix->get_pointer();
149-
for (int iw = 0; iw < atom1->nw; iw++)
150-
{
151-
for (int iw2 = 0; iw2 < atom2->nw; ++iw2)
152-
{
153-
conj_matrix->get_value(iw2, iw) = *tmp_pointer++;
154-
}
155-
}
156-
}
157-
++nad;
158-
} // end distane<rcut
81+
for (int ir = 0; ir < ap.get_R_size(); ir++)
82+
{
83+
auto R_index = ap.get_R_index(ir);
84+
auto upper_matrix = upper_ap->find_matrix(-R_index);
85+
auto lower_matrix = lower_ap->find_matrix(R_index);
86+
for (int irow = 0; irow < upper_matrix->get_row_size(); ++irow)
87+
{
88+
for (int icol = 0; icol < upper_matrix->get_col_size(); ++icol)
89+
{
90+
lower_matrix->get_value(icol, irow) = upper_ap->get_value(irow, icol);
15991
}
160-
} // end ad
92+
}
16193
}
162-
} // end ia
163-
} // end it
164-
165-
// ===================================
166-
// transfer HR from Gint to Veff<OperatorLCAO<std::complex<double>, double>>
167-
// ===================================
94+
}
95+
}
16896
#ifdef __MPI
169-
int size;
97+
int size = 0;
17098
MPI_Comm_size(MPI_COMM_WORLD, &size);
17199
if (size == 1)
172100
{
@@ -180,7 +108,6 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<double>* hR, const UnitCell* ucell
180108
hR->add(*this->hRGint);
181109
#endif
182110
ModuleBase::timer::tick("Gint_k", "transfer_pvpR");
183-
184111
return;
185112
}
186113

0 commit comments

Comments
 (0)