Skip to content

Commit e6a55ed

Browse files
authored
Removed the temporary variable hr_Gint_full_ when transitioning from 2D block parallelism to serial in Hcontainer (develop) (#6510)
* delete tem Hcontainer to reduce memory usage * simplify the compute code * change DM2D_tmp to dm2d_tmp, use vector instead of new * delete hr_gint_full_ to reduce memory * add serial code * fix bug * fix a not declare bug * fix serial bug * Fix the issue of not resetting to zero in the loop * resolve confilct * resolve confilct * Change the names to comply with the conventions
1 parent 1c54579 commit e6a55ed

21 files changed

+256
-194
lines changed

source/source_estate/elecstate_lcao.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<d
3535

3636
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
3737
#ifdef __OLD_GINT
38-
this->gint_k->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer DM2D to DM_grid in gint
38+
this->gint_k->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer dm2d to DM_grid in gint
3939
Gint_inout inout(this->charge->rho, Gint_Tools::job_type::rho, PARAM.inp.nspin);
4040
this->gint_k->cal_gint(&inout);
4141
#else
@@ -72,7 +72,7 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
7272
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
7373

7474
#ifdef __OLD_GINT
75-
this->gint_gamma->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer DM2D to DM_grid in gint
75+
this->gint_gamma->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer dm2d to DM_grid in gint
7676
Gint_inout inout(this->charge->rho, Gint_Tools::job_type::rho, PARAM.inp.nspin);
7777
this->gint_gamma->cal_gint(&inout);
7878
#else
@@ -140,7 +140,7 @@ void ElecStateLCAO<double>::dmToRho(std::vector<double*> pexsi_DM, std::vector<d
140140

141141
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
142142
#ifdef __OLD_GINT
143-
this->gint_gamma->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer DM2D to DM_grid in gint
143+
this->gint_gamma->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer dm2d to DM_grid in gint
144144
Gint_inout inout(this->charge->rho, Gint_Tools::job_type::rho, PARAM.inp.nspin);
145145
this->gint_gamma->cal_gint(&inout);
146146
#else

source/source_lcao/module_gint/gint.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Gint {
2626

2727
hamilt::HContainer<double>* get_hRGint() const { return hRGint; }
2828

29-
std::vector<hamilt::HContainer<double>*> get_DMRGint() const { return DMRGint; }
29+
std::vector<hamilt::HContainer<double>*> get_DMRGint() const { return dmr_gint; }
3030

3131
int get_ncxyz() const { return ncxyz; }
3232

@@ -58,14 +58,14 @@ class Gint {
5858
void initialize_pvpR(const UnitCell& unitcell, const Grid_Driver* gd, const int& nspin);
5959

6060
/**
61-
* @brief resize DMRGint to nspin and reallocate the memory
61+
* @brief resize dmr_gint to nspin and reallocate the memory
6262
*/
6363
void reset_DMRGint(const int& nspin);
6464

6565
/**
6666
* @brief transfer DMR (2D para) to DMR (Grid para) in elecstate_lcao.cpp
6767
*/
68-
void transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D);
68+
void transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> dm2d);
6969

7070
const Grid_Technique* gridt = nullptr;
7171
const UnitCell* ucell;
@@ -256,13 +256,13 @@ class Gint {
256256
hamilt::HContainer<double>* hRGint = nullptr;
257257

258258
//! size of vec is 4, only used when nspin = 4
259-
std::vector<hamilt::HContainer<double>*> hRGint_tmp;
259+
std::vector<hamilt::HContainer<double>*> hr_gint_tmp;
260260

261261
//! stores Hamiltonian in sparse format
262262
hamilt::HContainer<std::complex<double>>* hRGintCd = nullptr;
263263

264264
//! stores DMR in sparse format
265-
std::vector<hamilt::HContainer<double>*> DMRGint;
265+
std::vector<hamilt::HContainer<double>*> dmr_gint;
266266

267267
//! tmp tools used in transfer_DM2DtoGrid
268268
hamilt::HContainer<double>* dm2d_tmp = nullptr;

source/source_lcao/module_gint/gint_force_cpu_interface.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void Gint::gint_kernel_force(Gint_inout* inout) {
8282
cal_flag.get_ptr_2D(),
8383
psir_vlbr3.get_ptr_2D(),
8484
psir_vlbr3_DM.get_ptr_2D(),
85-
this->DMRGint[inout->ispin],
85+
this->dmr_gint[inout->ispin],
8686
false);
8787

8888
if(inout->isforce)
@@ -230,19 +230,19 @@ void Gint::gint_kernel_force_meta(Gint_inout* inout) {
230230
//calculating g_mu(r) = sum_nu rho_mu,nu f_nu(r)
231231
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, LD_pool, grid_index,
232232
na_grid, block_index.data(), block_size.data(), cal_flag.get_ptr_2D(),
233-
psir_vlbr3.get_ptr_2D(), psir_vlbr3_DM.get_ptr_2D(), this->DMRGint[inout->ispin], false);
233+
psir_vlbr3.get_ptr_2D(), psir_vlbr3_DM.get_ptr_2D(), this->dmr_gint[inout->ispin], false);
234234

235235
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, LD_pool, grid_index,
236236
na_grid, block_index.data(), block_size.data(), cal_flag.get_ptr_2D(),
237-
dpsir_x_vlbr3.get_ptr_2D(), dpsirx_v_DM.get_ptr_2D(), this->DMRGint[inout->ispin], false);
237+
dpsir_x_vlbr3.get_ptr_2D(), dpsirx_v_DM.get_ptr_2D(), this->dmr_gint[inout->ispin], false);
238238

239239
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, LD_pool, grid_index,
240240
na_grid, block_index.data(), block_size.data(), cal_flag.get_ptr_2D(),
241-
dpsir_y_vlbr3.get_ptr_2D(), dpsiry_v_DM.get_ptr_2D(), this->DMRGint[inout->ispin], false);
241+
dpsir_y_vlbr3.get_ptr_2D(), dpsiry_v_DM.get_ptr_2D(), this->dmr_gint[inout->ispin], false);
242242

243243
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, LD_pool, grid_index,
244244
na_grid, block_index.data(), block_size.data(), cal_flag.get_ptr_2D(),
245-
dpsir_z_vlbr3.get_ptr_2D(), dpsirz_v_DM.get_ptr_2D(), this->DMRGint[inout->ispin], false);
245+
dpsir_z_vlbr3.get_ptr_2D(), dpsirz_v_DM.get_ptr_2D(), this->dmr_gint[inout->ispin], false);
246246

247247
if(inout->isforce)
248248
{

source/source_lcao/module_gint/gint_gpu_interface.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ void Gint::gpu_vlocal_interface(Gint_inout* inout) {
1818
ylmcoef[i] = ModuleBase::Ylm::ylmcoef[i];
1919
}
2020

21-
hamilt::HContainer<double>* hRGint_kernel = PARAM.inp.nspin != 4 ? this->hRGint : this->hRGint_tmp[inout->ispin];
21+
hamilt::HContainer<double>* hRGint_kernel = PARAM.inp.nspin != 4 ? this->hRGint : this->hr_gint_tmp[inout->ispin];
2222
GintKernel::gint_vl_gpu(hRGint_kernel,
2323
inout->vl,
2424
ylmcoef,
@@ -45,7 +45,7 @@ void Gint::gpu_rho_interface(Gint_inout* inout) {
4545
int nrxx = this->gridt->ncx * this->gridt->ncy * this->nplane;
4646
for (int is = 0; is < PARAM.inp.nspin; ++is) {
4747
ModuleBase::GlobalFunc::ZEROS(inout->rho[is], nrxx);
48-
GintKernel::gint_rho_gpu(this->DMRGint[is],
48+
GintKernel::gint_rho_gpu(this->dmr_gint[is],
4949
ylmcoef,
5050
dr,
5151
this->gridt->rcuts.data(),
@@ -76,7 +76,7 @@ void Gint::gpu_force_interface(Gint_inout* inout) {
7676
if (isforce || isstress) {
7777
std::vector<double> force(nat * 3, 0.0);
7878
std::vector<double> stress(6, 0.0);
79-
GintKernel::gint_fvl_gpu(this->DMRGint[inout->ispin],
79+
GintKernel::gint_fvl_gpu(this->dmr_gint[inout->ispin],
8080
inout->vl,
8181
force.data(),
8282
stress.data(),

source/source_lcao/module_gint/gint_k_pvpr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<std::complex<double>>* hR,
8989
{
9090
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
9191
hamilt::AtomPair<std::complex<double>>* lower_ap = this->hRGintCd->find_pair(iat2, iat1);
92-
const hamilt::AtomPair<double>* ap_nspin_0 = this->hRGint_tmp[0]->find_pair(iat1, iat2);
93-
const hamilt::AtomPair<double>* ap_nspin_3 = this->hRGint_tmp[3]->find_pair(iat1, iat2);
92+
const hamilt::AtomPair<double>* ap_nspin_0 = this->hr_gint_tmp[0]->find_pair(iat1, iat2);
93+
const hamilt::AtomPair<double>* ap_nspin_3 = this->hr_gint_tmp[3]->find_pair(iat1, iat2);
9494
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
9595
{
9696
const auto R_index = upper_ap->get_R_index(ir);
@@ -110,8 +110,8 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<std::complex<double>>* hR,
110110

111111
if (PARAM.globalv.domag)
112112
{
113-
const hamilt::AtomPair<double>* ap_nspin_1 = this->hRGint_tmp[1]->find_pair(iat1, iat2);
114-
const hamilt::AtomPair<double>* ap_nspin_2 = this->hRGint_tmp[2]->find_pair(iat1, iat2);
113+
const hamilt::AtomPair<double>* ap_nspin_1 = this->hr_gint_tmp[1]->find_pair(iat1, iat2);
114+
const hamilt::AtomPair<double>* ap_nspin_2 = this->hr_gint_tmp[2]->find_pair(iat1, iat2);
115115
const auto mat_nspin_1 = ap_nspin_1->find_matrix(R_index);
116116
const auto mat_nspin_2 = ap_nspin_2->find_matrix(R_index);
117117
for (int irow = 0; irow < mat_nspin_1->get_row_size(); ++irow)

source/source_lcao/module_gint/gint_old.cpp

Lines changed: 69 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ Gint::~Gint() {
2424

2525
delete this->hRGint;
2626
delete this->hRGintCd;
27-
// in gamma_only case, DMRGint.size()=0,
28-
// in multi-k case, DMRGint.size()=nspin
29-
for (int is = 0; is < this->DMRGint.size(); is++) {
30-
delete this->DMRGint[is];
27+
// in gamma_only case, dmr_gint.size()=0,
28+
// in multi-k case, dmr_gint.size()=nspin
29+
for (int is = 0; is < this->dmr_gint.size(); is++) {
30+
delete this->dmr_gint[is];
3131
}
32-
for(int is = 0; is < this->hRGint_tmp.size(); is++) {
33-
delete this->hRGint_tmp[is];
32+
for(int is = 0; is < this->hr_gint_tmp .size(); is++) {
33+
delete this->hr_gint_tmp [is];
3434
}
3535
#ifdef __MPI
3636
delete this->dm2d_tmp;
@@ -141,13 +141,12 @@ void Gint::prep_grid(const Grid_Technique& gt,
141141
void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, const int& nspin)
142142
{
143143
ModuleBase::TITLE("Gint", "initialize_pvpR");
144-
145144
int npol = 1;
146-
// there is the only resize code of DMRGint
147-
if (this->DMRGint.size() == 0) {
148-
this->DMRGint.resize(nspin);
145+
// there is the only resize code of dmr_gint
146+
if (this->dmr_gint.size() == 0) {
147+
this->dmr_gint.resize(nspin);
149148
}
150-
hRGint_tmp.resize(nspin);
149+
hr_gint_tmp.resize(nspin);
151150
if (nspin != 4) {
152151
if (this->hRGint != nullptr) {
153152
delete this->hRGint;
@@ -161,22 +160,21 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
161160
this->hRGintCd
162161
= new hamilt::HContainer<std::complex<double>>(ucell_in.nat);
163162
for (int is = 0; is < nspin; is++) {
164-
if (this->DMRGint[is] != nullptr) {
165-
delete this->DMRGint[is];
163+
if (this->dmr_gint[is] != nullptr) {
164+
delete this->dmr_gint[is];
166165
}
167-
if (this->hRGint_tmp[is] != nullptr) {
168-
delete this->hRGint_tmp[is];
166+
if (this->hr_gint_tmp[is] != nullptr) {
167+
delete this->hr_gint_tmp[is];
169168
}
170-
this->DMRGint[is] = new hamilt::HContainer<double>(ucell_in.nat);
171-
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
169+
this->dmr_gint[is] = new hamilt::HContainer<double>(ucell_in.nat);
170+
this->hr_gint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172171
}
173172
#ifdef __MPI
174173
if (this->dm2d_tmp != nullptr) {
175174
delete this->dm2d_tmp;
176175
}
177176
#endif
178177
}
179-
180178
if (PARAM.globalv.gamma_only_local && nspin != 4) {
181179
this->hRGint->fix_gamma();
182180
}
@@ -185,94 +183,100 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
185183
this->hRGint->allocate(nullptr, true);
186184
ModuleBase::Memory::record("Gint::hRGint",
187185
this->hRGint->get_memory_size());
188-
// initialize DMRGint with hRGint when NSPIN != 4
189-
for (int is = 0; is < this->DMRGint.size(); is++) {
190-
if (this->DMRGint[is] != nullptr) {
191-
delete this->DMRGint[is];
186+
// initialize dmr_gint with hRGint when NSPIN != 4
187+
for (int is = 0; is < this->dmr_gint.size(); is++) {
188+
if (this->dmr_gint[is] != nullptr) {
189+
delete this->dmr_gint[is];
192190
}
193-
this->DMRGint[is] = new hamilt::HContainer<double>(*this->hRGint);
191+
this->dmr_gint[is] = new hamilt::HContainer<double>(*this->hRGint);
194192
}
195-
ModuleBase::Memory::record("Gint::DMRGint",
196-
this->DMRGint[0]->get_memory_size()
197-
* this->DMRGint.size());
193+
ModuleBase::Memory::record("Gint::dmr_gint",
194+
this->dmr_gint[0]->get_memory_size()
195+
* this->dmr_gint.size());
198196
} else {
199197
this->hRGintCd->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
200198
this->hRGintCd->allocate(nullptr, true);
201199
for(int is = 0; is < nspin; is++) {
202-
this->hRGint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
203-
this->DMRGint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
204-
this->hRGint_tmp[is]->allocate(nullptr, true);
205-
this->DMRGint[is]->allocate(nullptr, true);
200+
this->hr_gint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
201+
this->dmr_gint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
202+
this->hr_gint_tmp[is]->allocate(nullptr, true);
203+
this->dmr_gint[is]->allocate(nullptr, true);
206204
}
207-
ModuleBase::Memory::record("Gint::hRGint_tmp",
208-
this->hRGint_tmp[0]->get_memory_size()*nspin);
209-
ModuleBase::Memory::record("Gint::DMRGint",
210-
this->DMRGint[0]->get_memory_size()
211-
* this->DMRGint.size()*nspin);
205+
ModuleBase::Memory::record("Gint::hr_gint_tmp",
206+
this->hr_gint_tmp[0]->get_memory_size()*nspin);
207+
ModuleBase::Memory::record("Gint::dmr_gint",
208+
this->dmr_gint[0]->get_memory_size()
209+
* this->dmr_gint.size()*nspin);
212210
}
213211
}
214212

215213
void Gint::reset_DMRGint(const int& nspin)
216214
{
217215
if (this->hRGint)
218216
{
219-
for (auto& d : this->DMRGint) { delete d; }
220-
this->DMRGint.resize(nspin);
221-
this->DMRGint.shrink_to_fit();
222-
for (auto& d : this->DMRGint) { d = new hamilt::HContainer<double>(*this->hRGint); }
217+
for (auto& d : this->dmr_gint) { delete d; }
218+
this->dmr_gint.resize(nspin);
219+
this->dmr_gint.shrink_to_fit();
220+
for (auto& d : this->dmr_gint) { d = new hamilt::HContainer<double>(*this->hRGint); }
223221
if (nspin == 4)
224222
{
225-
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
223+
for (auto& d : this->dmr_gint) { d->allocate(nullptr, false); }
226224
#ifdef __MPI
227225
delete this->dm2d_tmp;
228226
#endif
229227
}
230228
}
231229
}
232230

233-
void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
231+
void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> dm2d) {
234232
ModuleBase::TITLE("Gint", "transfer_DMR");
235-
236-
// To check whether input parameter DM2D has been initialized
233+
// To check whether input parameter dm2d has been initialized
237234
#ifdef __DEBUG
238-
assert(!DM2D.empty()
239-
&& "Input parameter DM2D has not been initialized while calling "
235+
assert(!dm2d.empty()
236+
&& "Input parameter dm2d has not been initialized while calling "
240237
"function transfer_DM2DtoGrid!");
241238
#endif
242-
243239
ModuleBase::timer::tick("Gint", "transfer_DMR");
244240
if (PARAM.inp.nspin != 4) {
245-
for (int is = 0; is < this->DMRGint.size(); is++) {
241+
for (int is = 0; is < this->dmr_gint.size(); is++) {
246242
#ifdef __MPI
247-
hamilt::transferParallels2Serials(*DM2D[is], DMRGint[is]);
243+
hamilt::transferParallels2Serials(*dm2d[is], dmr_gint[is]);
248244
#else
249-
this->DMRGint[is]->set_zero();
250-
this->DMRGint[is]->add(*DM2D[is]);
245+
this->dmr_gint[is]->set_zero();
246+
this->dmr_gint[is]->add(*dm2d[is]);
251247
#endif
252248
}
253249
} else // NSPIN=4 case
254250
{
255-
#ifdef __MPI
256251
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257252
const int row_set[4] = {0, 0, 1, 1};
258253
const int col_set[4] = {0, 1, 0, 1};
259-
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
260-
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
261-
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
262-
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
254+
int mg = dm2d[0]->get_paraV()->get_global_row_size()/2;
255+
int ng = dm2d[0]->get_paraV()->get_global_col_size()/2;
256+
int nb = dm2d[0]->get_paraV()->get_block_size()/2;
257+
auto ijr_info = dm2d[0]->get_ijr_info();
258+
#ifdef __MPI
259+
int blacs_ctxt = dm2d[0]->get_paraV()->blacs_ctxt;
263260
std::vector<int> iat2iwt(ucell->nat);
264261
for (int iat = 0; iat < ucell->nat; iat++) {
265262
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
266263
}
267264
Parallel_Orbitals pv{};
268265
pv.set(mg, ng, nb, blacs_ctxt);
269266
pv.set_atomic_trace(iat2iwt.data(), ucell->nat, mg);
270-
auto ijr_info = DM2D[0]->get_ijr_info();
271267
this-> dm2d_tmp = new hamilt::HContainer<double>(&pv, nullptr, &ijr_info);
268+
#else
269+
if (this->dm2d_tmp != nullptr) {
270+
delete this->dm2d_tmp;
271+
}
272+
this-> dm2d_tmp = new hamilt::HContainer<double>(*this->hRGint);
273+
this-> dm2d_tmp -> insert_ijrs(this->gridt->get_ijr_info(), *(this->ucell));
274+
this-> dm2d_tmp -> allocate(nullptr, true);
275+
#endif
272276
ModuleBase::Memory::record("Gint::dm2d_tmp", this->dm2d_tmp->get_memory_size());
273277
for (int is = 0; is < 4; is++){
274-
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
275-
auto& ap = DM2D[0]->get_atom_pair(iap);
278+
for (int iap = 0; iap < dm2d[0]->size_atom_pairs(); ++iap) {
279+
auto& ap = dm2d[0]->get_atom_pair(iap);
276280
int iat1 = ap.get_atom_i();
277281
int iat2 = ap.get_atom_j();
278282
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
@@ -288,13 +292,15 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
288292
}
289293
}
290294
}
291-
hamilt::transferParallels2Serials( *(this->dm2d_tmp), this->DMRGint[is]);
292-
}
293-
delete this->dm2d_tmp;
294-
this->dm2d_tmp = nullptr;
295+
#ifdef __MPI
296+
hamilt::transferParallels2Serials( *(this->dm2d_tmp), this->dmr_gint[is]);
295297
#else
296-
//this->DMRGint_full = DM2D[0];
298+
this->dmr_gint[is]->set_zero();
299+
this->dmr_gint[is]->add(*(this->dm2d_tmp));
297300
#endif
301+
}//is=4
302+
delete this->dm2d_tmp;
303+
this->dm2d_tmp = nullptr;
298304
}
299305
ModuleBase::timer::tick("Gint", "transfer_DMR");
300306
}

0 commit comments

Comments
 (0)