Skip to content

Commit 7d4fe5a

Browse files
committed
delete tem Hcontainer to reduce memory usage
1 parent 8a3ea2a commit 7d4fe5a

File tree

5 files changed

+121
-77
lines changed

5 files changed

+121
-77
lines changed

source/source_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class Gint {
265265
std::vector<hamilt::HContainer<double>*> DMRGint;
266266

267267
//! tmp tools used in transfer_DM2DtoGrid
268-
hamilt::HContainer<double>* DMRGint_full = nullptr;
268+
hamilt::HContainer<double>* DM2D_tmp = nullptr;
269269

270270
std::vector<hamilt::HContainer<double>> pvdpRx_reduced;
271271
std::vector<hamilt::HContainer<double>> pvdpRy_reduced;

source/source_lcao/module_gint/gint_old.cpp

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Gint::~Gint() {
3333
delete this->hRGint_tmp[is];
3434
}
3535
#ifdef __MPI
36-
delete this->DMRGint_full;
36+
delete this->DM2D_tmp;
3737
#endif
3838
}
3939

@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172172
}
173173
#ifdef __MPI
174-
if (this->DMRGint_full != nullptr) {
175-
delete this->DMRGint_full;
174+
if (this->DM2D_tmp != nullptr) {
175+
delete this->DM2D_tmp;
176176
}
177-
this->DMRGint_full = new hamilt::HContainer<double>(ucell_in.nat);
178177
#endif
179178
}
180179

@@ -210,12 +209,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
210209
ModuleBase::Memory::record("Gint::DMRGint",
211210
this->DMRGint[0]->get_memory_size()
212211
* this->DMRGint.size()*nspin);
213-
#ifdef __MPI
214-
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
215-
this->DMRGint_full->allocate(nullptr, true);
216-
ModuleBase::Memory::record("Gint::DMRGint_full",
217-
this->DMRGint_full->get_memory_size());
218-
#endif
219212
}
220213
}
221214

@@ -231,9 +224,7 @@ void Gint::reset_DMRGint(const int& nspin)
231224
{
232225
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
233226
#ifdef __MPI
234-
delete this->DMRGint_full;
235-
this->DMRGint_full = new hamilt::HContainer<double>(*this->hRGint);
236-
this->DMRGint_full->allocate(nullptr, false);
227+
delete this->DM2D_tmp;
237228
#endif
238229
}
239230
}
@@ -262,37 +253,66 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262253
} else // NSPIN=4 case
263254
{
264255
#ifdef __MPI
265-
hamilt::transferParallels2Serials(*DM2D[0], this->DMRGint_full);
266-
#else
267-
this->DMRGint_full = DM2D[0];
268-
#endif
269-
std::vector<double*> tmp_pointer(4, nullptr);
270-
for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
271-
auto& ap = this->DMRGint_full->get_atom_pair(iap);
272-
int iat1 = ap.get_atom_i();
273-
int iat2 = ap.get_atom_j();
274-
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
275-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
276-
for (int is = 0; is < 4; is++) {
277-
tmp_pointer[is] = this->DMRGint[is]
278-
->find_matrix(iat1, iat2, r_index)
279-
->get_pointer();
280-
}
281-
double* data_full = ap.get_pointer(ir);
282-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
283-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284-
*(tmp_pointer[0])++ = data_full[icol];
285-
*(tmp_pointer[1])++ = data_full[icol + 1];
286-
}
287-
data_full += ap.get_col_size();
288-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289-
*(tmp_pointer[2])++ = data_full[icol];
290-
*(tmp_pointer[3])++ = data_full[icol + 1];
256+
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
257+
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
258+
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
259+
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
260+
int *iat2iwt = new int[ucell->nat];
261+
for (int iat = 0; iat < ucell->nat; iat++) {
262+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
263+
}
264+
Parallel_Orbitals *pv = new Parallel_Orbitals();
265+
pv->set(mg, ng, nb, blacs_ctxt);
266+
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
267+
auto ijr_info = DM2D[0]->get_ijr_info();
268+
this-> DM2D_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
269+
ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
270+
for (int is = 0; is < 4; is++){
271+
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
272+
auto& ap = DM2D[0]->get_atom_pair(iap);
273+
int iat1 = ap.get_atom_i();
274+
int iat2 = ap.get_atom_j();
275+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
276+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
277+
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
278+
double* data_full = ap.get_pointer(ir);
279+
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
280+
switch (is) {//todo: It can be written more compactly
281+
case 0:
282+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
283+
*(tmp_pointer)++ = data_full[icol];
284+
}
285+
data_full += ap.get_col_size() * 2;
286+
break;
287+
case 1:
288+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289+
*(tmp_pointer)++ = data_full[icol + 1];
290+
}
291+
data_full += ap.get_col_size() * 2;
292+
break;
293+
case 2:
294+
data_full += ap.get_col_size();
295+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
296+
*(tmp_pointer)++ = data_full[icol];
297+
}
298+
data_full += ap.get_col_size();
299+
break;
300+
case 3:
301+
data_full += ap.get_col_size();
302+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
303+
*(tmp_pointer)++ = data_full[icol + 1];
304+
}
305+
data_full += ap.get_col_size();
306+
break;
307+
}
291308
}
292-
data_full += ap.get_col_size();
293309
}
294310
}
311+
hamilt::transferParallels2Serials( *(this->DM2D_tmp), this->DMRGint[is]);
295312
}
313+
#else
314+
//this->DMRGint_full = DM2D[0];
315+
#endif
296316
}
297317
ModuleBase::timer::tick("Gint", "transfer_DMR");
298318
}

source/source_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -163,44 +163,67 @@ void transfer_dm_2d_to_gint(
163163
} else // NSPIN=4 case
164164
{
165165
#ifdef __MPI
166-
const int npol = 2;
167-
HContainer<T> dm_full = gint_info.get_hr<T>(npol);
168-
hamilt::transferParallels2Serials(*dm[0], &dm_full);
169-
#else
170-
HContainer<T>& dm_full = *(dm[0]);
171-
#endif
172-
std::vector<T*> tmp_pointer(4, nullptr);
173-
for (int iap = 0; iap < dm_full.size_atom_pairs(); iap++)
174-
{
175-
auto& ap = dm_full.get_atom_pair(iap);
176-
const int iat1 = ap.get_atom_i();
177-
const int iat2 = ap.get_atom_j();
178-
for (int ir = 0; ir < ap.get_R_size(); ir++)
179-
{
180-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
181-
for (int is = 0; is < 4; is++)
182-
{
183-
tmp_pointer[is] =
184-
dm_gint[is].find_matrix(iat1, iat2, r_index)->get_pointer();
185-
}
186-
T* data_full = ap.get_pointer(ir);
187-
for (int irow = 0; irow < ap.get_row_size(); irow += 2)
188-
{
189-
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
190-
{
191-
*(tmp_pointer[0])++ = data_full[icol];
192-
*(tmp_pointer[1])++ = data_full[icol + 1];
193-
}
194-
data_full += ap.get_col_size();
195-
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
196-
{
197-
*(tmp_pointer[2])++ = data_full[icol];
198-
*(tmp_pointer[3])++ = data_full[icol + 1];
166+
int mg = dm[0]->get_paraV()->get_global_row_size()/2;
167+
int ng = dm[0]->get_paraV()->get_global_col_size()/2;
168+
int nb = dm[0]->get_paraV()->get_block_size()/2;
169+
int blacs_ctxt = dm[0]->get_paraV()->blacs_ctxt;
170+
const UnitCell* ucell = gint_info.get_ucell();
171+
int *iat2iwt = new int[ucell->nat];
172+
for (int iat = 0; iat < ucell->nat; iat++) {
173+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
174+
}
175+
Parallel_Orbitals *pv = new Parallel_Orbitals();
176+
pv->set(mg, ng, nb, blacs_ctxt);
177+
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
178+
auto ijr_info = dm[0]->get_ijr_info();
179+
HContainer<T>* DM2D_tmp = new hamilt::HContainer<T>(pv, nullptr, &ijr_info);
180+
//ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
181+
for (int is = 0; is < 4; is++){
182+
for (int iap = 0; iap < dm[0]->size_atom_pairs(); ++iap) {
183+
auto& ap = dm[0]->get_atom_pair(iap);
184+
int iat1 = ap.get_atom_i();
185+
int iat2 = ap.get_atom_j();
186+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
187+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
188+
T* tmp_pointer = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
189+
T* data_full = ap.get_pointer(ir);
190+
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
191+
switch (is) {//todo: It can be written more compactly
192+
case 0:
193+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
194+
*(tmp_pointer)++ = data_full[icol];
195+
}
196+
data_full += ap.get_col_size() * 2;
197+
break;
198+
case 1:
199+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
200+
*(tmp_pointer)++ = data_full[icol + 1];
201+
}
202+
data_full += ap.get_col_size() * 2;
203+
break;
204+
case 2:
205+
data_full += ap.get_col_size();
206+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
207+
*(tmp_pointer)++ = data_full[icol];
208+
}
209+
data_full += ap.get_col_size();
210+
break;
211+
case 3:
212+
data_full += ap.get_col_size();
213+
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
214+
*(tmp_pointer)++ = data_full[icol + 1];
215+
}
216+
data_full += ap.get_col_size();
217+
break;
218+
}
199219
}
200-
data_full += ap.get_col_size();
201220
}
202221
}
222+
hamilt::transferParallels2Serials( *DM2D_tmp, &dm_gint[is]);
203223
}
224+
#else
225+
//HContainer<T>& dm_full = *(dm[0]);
226+
#endif
204227
}
205228
ModuleBase::timer::tick("Gint", "transfer_dm_2d_to_gint");
206229
}

source/source_lcao/module_gint/temp_gint/gint_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class GintInfo
3838
const std::vector<int>& get_trace_lo() const{ return trace_lo_; }
3939
int get_lgd() const { return lgd_; }
4040
int get_nat() const { return ucell_->nat; } // return the number of atoms in the unitcell
41+
const UnitCell* get_ucell() const { return ucell_; }
4142
int get_local_mgrid_num() const { return localcell_info_->get_mgrids_num(); }
4243
double get_mgrid_volume() const { return meshgrid_info_->get_volume(); }
4344

source/source_lcao/module_lr/utils/gint_move.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ Gint& Gint::operator=(Gint&& rhs)
6060
this->pvdpRz_reduced = std::move(rhs.pvdpRz_reduced);
6161
this->DMRGint = std::move(rhs.DMRGint);
6262
this->hRGint_tmp = std::move(rhs.hRGint_tmp);
63-
this->DMRGint_full = rhs.DMRGint_full;
64-
rhs.DMRGint_full = nullptr;
63+
this->DM2D_tmp = rhs.DM2D_tmp;
64+
rhs.DM2D_tmp = nullptr;
6565

6666
return *this;
6767
}

0 commit comments

Comments
 (0)