Skip to content

Commit 07a9a23

Browse files
committed
delete tem Hcontainer to reduce memory usage
1 parent 44b2136 commit 07a9a23

File tree

3 files changed

+67
-44
lines changed

3 files changed

+67
-44
lines changed

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Gint::~Gint() {
3333
delete this->hRGint_tmp[is];
3434
}
3535
#ifdef __MPI
36-
delete this->DMRGint_full;
36+
delete this->DM2D_tmp;
3737
#endif
3838
}
3939

@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172172
}
173173
#ifdef __MPI
174-
if (this->DMRGint_full != nullptr) {
175-
delete this->DMRGint_full;
174+
if (this->DM2D_tmp != nullptr) {
175+
delete this->DM2D_tmp;
176176
}
177-
this->DMRGint_full = new hamilt::HContainer<double>(ucell_in.nat);
178177
#endif
179178
}
180179

@@ -211,12 +210,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
211210
this->hRGint_tmp[0]->get_memory_size()*nspin);
212211
ModuleBase::Memory::record("Gint::DMRGint",
213212
this->DMRGint[0]->get_memory_size()*nspin);
214-
#ifdef __MPI
215-
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
216-
this->DMRGint_full->allocate(nullptr, true);
217-
ModuleBase::Memory::record("Gint::DMRGint_full",
218-
this->DMRGint_full->get_memory_size());
219-
#endif
220213
}
221214
}
222215

@@ -232,10 +225,8 @@ void Gint::reset_DMRGint(const int& nspin)
232225
{
233226
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
234227
#ifdef __MPI
235-
delete this->DMRGint_full;
236-
this->DMRGint_full = new hamilt::HContainer<double>(*this->hRGint);
237-
this->DMRGint_full->allocate(nullptr, false);
238-
#endif
228+
delete this->DM2D_tmp;
229+
#endif
239230
}
240231
}
241232
}
@@ -263,37 +254,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
263254
} else // NSPIN=4 case
264255
{
265256
#ifdef __MPI
266-
hamilt::transferParallels2Serials(*DM2D[0], this->DMRGint_full);
267-
#else
268-
this->DMRGint_full = DM2D[0];
269-
#endif
270-
std::vector<double*> tmp_pointer(4, nullptr);
271-
for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
272-
auto& ap = this->DMRGint_full->get_atom_pair(iap);
273-
int iat1 = ap.get_atom_i();
274-
int iat2 = ap.get_atom_j();
275-
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
276-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
277-
for (int is = 0; is < 4; is++) {
278-
tmp_pointer[is] = this->DMRGint[is]
279-
->find_matrix(iat1, iat2, r_index)
280-
->get_pointer();
281-
}
282-
double* data_full = ap.get_pointer(ir);
283-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
284-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
285-
*(tmp_pointer[0])++ = data_full[icol];
286-
*(tmp_pointer[1])++ = data_full[icol + 1];
287-
}
288-
data_full += ap.get_col_size();
289-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
290-
*(tmp_pointer[2])++ = data_full[icol];
291-
*(tmp_pointer[3])++ = data_full[icol + 1];
257+
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
258+
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
259+
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
260+
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
261+
int *iat2iwt = new int[ucell->nat];
262+
for (int iat = 0; iat < ucell->nat; iat++) {
263+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
264+
}
265+
Parallel_Orbitals *pv = new Parallel_Orbitals();
266+
pv->set(mg, ng, nb, blacs_ctxt);
267+
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
268+
auto ijr_info = DM2D[0]->get_ijr_info();
269+
this-> DM2D_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
270+
ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
271+
for (int is = 0; is < 4; is++){
272+
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
273+
auto& ap = DM2D[0]->get_atom_pair(iap);
274+
int iat1 = ap.get_atom_i();
275+
int iat2 = ap.get_atom_j();
276+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
277+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
278+
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
279+
double* data_full = ap.get_pointer(ir);
280+
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
281+
switch (is) {//todo: It can be written more compactly
282+
case 0:
283+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284+
*(tmp_pointer)++ = data_full[icol];
285+
}
286+
data_full += ap.get_col_size() * 2;
287+
break;
288+
case 1:
289+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
290+
*(tmp_pointer)++ = data_full[icol + 1];
291+
}
292+
data_full += ap.get_col_size() * 2;
293+
break;
294+
case 2:
295+
data_full += ap.get_col_size();
296+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
297+
*(tmp_pointer)++ = data_full[icol];
298+
}
299+
data_full += ap.get_col_size();
300+
break;
301+
case 3:
302+
data_full += ap.get_col_size();
303+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
304+
*(tmp_pointer)++ = data_full[icol + 1];
305+
}
306+
data_full += ap.get_col_size();
307+
break;
308+
}
292309
}
293-
data_full += ap.get_col_size();
294310
}
295311
}
312+
hamilt::transferParallels2Serials( *(this->DM2D_tmp), this->DMRGint[is]);
296313
}
314+
// delete iat2iwt;
315+
// delete pv;
316+
// delete this-> DM2D_tmp;
317+
#else
318+
//this->DMRGint_full = DM2D[0];
319+
#endif
297320
}
298321
ModuleBase::timer::tick("Gint", "transfer_DMR");
299322
}

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class Gint {
264264
std::vector<hamilt::HContainer<double>*> DMRGint;
265265

266266
//! tmp tools used in transfer_DM2DtoGrid
267-
hamilt::HContainer<double>* DMRGint_full = nullptr;
267+
hamilt::HContainer<double>* DM2D_tmp = nullptr;
268268

269269
std::vector<hamilt::HContainer<double>> pvdpRx_reduced;
270270
std::vector<hamilt::HContainer<double>> pvdpRy_reduced;

source/module_lr/utils/gint_move.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ Gint& Gint::operator=(Gint&& rhs)
6060
this->pvdpRz_reduced = std::move(rhs.pvdpRz_reduced);
6161
this->DMRGint = std::move(rhs.DMRGint);
6262
this->hRGint_tmp = std::move(rhs.hRGint_tmp);
63-
this->DMRGint_full = rhs.DMRGint_full;
64-
rhs.DMRGint_full = nullptr;
63+
this->DM2D_tmp = rhs.DM2D_tmp;
64+
rhs.DM2D_tmp = nullptr;
6565

6666
return *this;
6767
}

0 commit comments

Comments
 (0)