Skip to content

Commit f2fb9af

Browse files
zgn-26714dyzheng
andauthored
Removed the temporary variable DMRGint_full when transitioning from 2D block parallelism to serial in Hcontainer. (#6487)
* Fixed the bug in memory statistics * delete tem Hcontainer to reduce memory usage --------- Co-authored-by: dyzheng <[email protected]>
1 parent 52457bf commit f2fb9af

File tree

3 files changed

+70
-46
lines changed

3 files changed

+70
-46
lines changed

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Gint::~Gint() {
3333
delete this->hRGint_tmp[is];
3434
}
3535
#ifdef __MPI
36-
delete this->DMRGint_full;
36+
delete this->DM2D_tmp;
3737
#endif
3838
}
3939

@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172172
}
173173
#ifdef __MPI
174-
if (this->DMRGint_full != nullptr) {
175-
delete this->DMRGint_full;
174+
if (this->DM2D_tmp != nullptr) {
175+
delete this->DM2D_tmp;
176176
}
177-
this->DMRGint_full = new hamilt::HContainer<double>(ucell_in.nat);
178177
#endif
179178
}
180179

@@ -199,6 +198,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
199198
} else {
200199
this->hRGintCd->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
201200
this->hRGintCd->allocate(nullptr, true);
201+
ModuleBase::Memory::record("Gint::hRGintCd",
202+
this->hRGintCd->get_memory_size());
202203
for(int is = 0; is < nspin; is++) {
203204
this->hRGint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
204205
this->DMRGint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
@@ -208,14 +209,7 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
208209
ModuleBase::Memory::record("Gint::hRGint_tmp",
209210
this->hRGint_tmp[0]->get_memory_size()*nspin);
210211
ModuleBase::Memory::record("Gint::DMRGint",
211-
this->DMRGint[0]->get_memory_size()
212-
* this->DMRGint.size()*nspin);
213-
#ifdef __MPI
214-
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
215-
this->DMRGint_full->allocate(nullptr, true);
216-
ModuleBase::Memory::record("Gint::DMRGint_full",
217-
this->DMRGint_full->get_memory_size());
218-
#endif
212+
this->DMRGint[0]->get_memory_size()*nspin);
219213
}
220214
}
221215

@@ -231,10 +225,8 @@ void Gint::reset_DMRGint(const int& nspin)
231225
{
232226
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
233227
#ifdef __MPI
234-
delete this->DMRGint_full;
235-
this->DMRGint_full = new hamilt::HContainer<double>(*this->hRGint);
236-
this->DMRGint_full->allocate(nullptr, false);
237-
#endif
228+
delete this->DM2D_tmp;
229+
#endif
238230
}
239231
}
240232
}
@@ -262,37 +254,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262254
} else // NSPIN=4 case
263255
{
264256
#ifdef __MPI
265-
hamilt::transferParallels2Serials(*DM2D[0], this->DMRGint_full);
266-
#else
267-
this->DMRGint_full = DM2D[0];
268-
#endif
269-
std::vector<double*> tmp_pointer(4, nullptr);
270-
for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
271-
auto& ap = this->DMRGint_full->get_atom_pair(iap);
272-
int iat1 = ap.get_atom_i();
273-
int iat2 = ap.get_atom_j();
274-
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
275-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
276-
for (int is = 0; is < 4; is++) {
277-
tmp_pointer[is] = this->DMRGint[is]
278-
->find_matrix(iat1, iat2, r_index)
279-
->get_pointer();
280-
}
281-
double* data_full = ap.get_pointer(ir);
282-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
283-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284-
*(tmp_pointer[0])++ = data_full[icol];
285-
*(tmp_pointer[1])++ = data_full[icol + 1];
286-
}
287-
data_full += ap.get_col_size();
288-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289-
*(tmp_pointer[2])++ = data_full[icol];
290-
*(tmp_pointer[3])++ = data_full[icol + 1];
257+
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
258+
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
259+
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
260+
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
261+
int *iat2iwt = new int[ucell->nat];
262+
for (int iat = 0; iat < ucell->nat; iat++) {
263+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
264+
}
265+
Parallel_Orbitals *pv = new Parallel_Orbitals();
266+
pv->set(mg, ng, nb, blacs_ctxt);
267+
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
268+
auto ijr_info = DM2D[0]->get_ijr_info();
269+
this-> DM2D_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
270+
ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
271+
for (int is = 0; is < 4; is++){
272+
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
273+
auto& ap = DM2D[0]->get_atom_pair(iap);
274+
int iat1 = ap.get_atom_i();
275+
int iat2 = ap.get_atom_j();
276+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
277+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
278+
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
279+
double* data_full = ap.get_pointer(ir);
280+
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
281+
switch (is) {//todo: It can be written more compactly
282+
case 0:
283+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284+
*(tmp_pointer)++ = data_full[icol];
285+
}
286+
data_full += ap.get_col_size() * 2;
287+
break;
288+
case 1:
289+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
290+
*(tmp_pointer)++ = data_full[icol + 1];
291+
}
292+
data_full += ap.get_col_size() * 2;
293+
break;
294+
case 2:
295+
data_full += ap.get_col_size();
296+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
297+
*(tmp_pointer)++ = data_full[icol];
298+
}
299+
data_full += ap.get_col_size();
300+
break;
301+
case 3:
302+
data_full += ap.get_col_size();
303+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
304+
*(tmp_pointer)++ = data_full[icol + 1];
305+
}
306+
data_full += ap.get_col_size();
307+
break;
308+
}
291309
}
292-
data_full += ap.get_col_size();
293310
}
294311
}
312+
hamilt::transferParallels2Serials( *(this->DM2D_tmp), this->DMRGint[is]);
295313
}
314+
// delete iat2iwt;
315+
// delete pv;
316+
// delete this-> DM2D_tmp;
317+
#else
318+
//this->DMRGint_full = DM2D[0];
319+
#endif
296320
}
297321
ModuleBase::timer::tick("Gint", "transfer_DMR");
298322
}

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class Gint {
264264
std::vector<hamilt::HContainer<double>*> DMRGint;
265265

266266
//! tmp tools used in transfer_DM2DtoGrid
267-
hamilt::HContainer<double>* DMRGint_full = nullptr;
267+
hamilt::HContainer<double>* DM2D_tmp = nullptr;
268268

269269
std::vector<hamilt::HContainer<double>> pvdpRx_reduced;
270270
std::vector<hamilt::HContainer<double>> pvdpRy_reduced;

source/module_lr/utils/gint_move.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ Gint& Gint::operator=(Gint&& rhs)
6060
this->pvdpRz_reduced = std::move(rhs.pvdpRz_reduced);
6161
this->DMRGint = std::move(rhs.DMRGint);
6262
this->hRGint_tmp = std::move(rhs.hRGint_tmp);
63-
this->DMRGint_full = rhs.DMRGint_full;
64-
rhs.DMRGint_full = nullptr;
63+
this->DM2D_tmp = rhs.DM2D_tmp;
64+
rhs.DM2D_tmp = nullptr;
6565

6666
return *this;
6767
}

0 commit comments

Comments
 (0)