Skip to content

Commit 0817e32

Browse files
authored
Remove the temporary variable DMRGint_full used when converting from the 2D block-parallel layout to the serial layout in HContainer (develop) (#6489)
* delete the temporary HContainer to reduce memory usage * simplify the computation code * rename DM2D_tmp to dm2d_tmp, use vector instead of new
1 parent 305bbf6 commit 0817e32

File tree

5 files changed

+80
-77
lines changed

5 files changed

+80
-77
lines changed

source/source_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class Gint {
265265
std::vector<hamilt::HContainer<double>*> DMRGint;
266266

267267
//! tmp tools used in transfer_DM2DtoGrid
268-
hamilt::HContainer<double>* DMRGint_full = nullptr;
268+
hamilt::HContainer<double>* dm2d_tmp = nullptr;
269269

270270
std::vector<hamilt::HContainer<double>> pvdpRx_reduced;
271271
std::vector<hamilt::HContainer<double>> pvdpRy_reduced;

source/source_lcao/module_gint/gint_old.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Gint::~Gint() {
3333
delete this->hRGint_tmp[is];
3434
}
3535
#ifdef __MPI
36-
delete this->DMRGint_full;
36+
delete this->dm2d_tmp;
3737
#endif
3838
}
3939

@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171
this->hRGint_tmp[is] = new hamilt::HContainer<double>(ucell_in.nat);
172172
}
173173
#ifdef __MPI
174-
if (this->DMRGint_full != nullptr) {
175-
delete this->DMRGint_full;
174+
if (this->dm2d_tmp != nullptr) {
175+
delete this->dm2d_tmp;
176176
}
177-
this->DMRGint_full = new hamilt::HContainer<double>(ucell_in.nat);
178177
#endif
179178
}
180179

@@ -210,12 +209,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
210209
ModuleBase::Memory::record("Gint::DMRGint",
211210
this->DMRGint[0]->get_memory_size()
212211
* this->DMRGint.size()*nspin);
213-
#ifdef __MPI
214-
this->DMRGint_full->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
215-
this->DMRGint_full->allocate(nullptr, true);
216-
ModuleBase::Memory::record("Gint::DMRGint_full",
217-
this->DMRGint_full->get_memory_size());
218-
#endif
219212
}
220213
}
221214

@@ -231,9 +224,7 @@ void Gint::reset_DMRGint(const int& nspin)
231224
{
232225
for (auto& d : this->DMRGint) { d->allocate(nullptr, false); }
233226
#ifdef __MPI
234-
delete this->DMRGint_full;
235-
this->DMRGint_full = new hamilt::HContainer<double>(*this->hRGint);
236-
this->DMRGint_full->allocate(nullptr, false);
227+
delete this->dm2d_tmp;
237228
#endif
238229
}
239230
}
@@ -262,37 +253,46 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262253
} else // NSPIN=4 case
263254
{
264255
#ifdef __MPI
265-
hamilt::transferParallels2Serials(*DM2D[0], this->DMRGint_full);
266-
#else
267-
this->DMRGint_full = DM2D[0];
268-
#endif
269-
std::vector<double*> tmp_pointer(4, nullptr);
270-
for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
271-
auto& ap = this->DMRGint_full->get_atom_pair(iap);
272-
int iat1 = ap.get_atom_i();
273-
int iat2 = ap.get_atom_j();
274-
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
275-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
276-
for (int is = 0; is < 4; is++) {
277-
tmp_pointer[is] = this->DMRGint[is]
278-
->find_matrix(iat1, iat2, r_index)
279-
->get_pointer();
280-
}
281-
double* data_full = ap.get_pointer(ir);
282-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
283-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284-
*(tmp_pointer[0])++ = data_full[icol];
285-
*(tmp_pointer[1])++ = data_full[icol + 1];
286-
}
287-
data_full += ap.get_col_size();
288-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289-
*(tmp_pointer[2])++ = data_full[icol];
290-
*(tmp_pointer[3])++ = data_full[icol + 1];
256+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257+
const int row_set[4] = {0, 0, 1, 1};
258+
const int col_set[4] = {0, 1, 0, 1};
259+
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
260+
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
261+
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
262+
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
263+
std::vector<int> iat2iwt(ucell->nat);
264+
for (int iat = 0; iat < ucell->nat; iat++) {
265+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
266+
}
267+
Parallel_Orbitals *pv = new Parallel_Orbitals();
268+
pv->set(mg, ng, nb, blacs_ctxt);
269+
pv->set_atomic_trace(iat2iwt.data(), ucell->nat, mg);
270+
auto ijr_info = DM2D[0]->get_ijr_info();
271+
this-> dm2d_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
272+
ModuleBase::Memory::record("Gint::dm2d_tmp", this->dm2d_tmp->get_memory_size());
273+
for (int is = 0; is < 4; is++){
274+
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
275+
auto& ap = DM2D[0]->get_atom_pair(iap);
276+
int iat1 = ap.get_atom_i();
277+
int iat2 = ap.get_atom_j();
278+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
279+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
280+
double* matrix_out = this -> dm2d_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
281+
double* matrix_in = ap.get_pointer(ir);
282+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
283+
for (int icol = 0; icol < ap.get_col_size()/2; icol++){
284+
int index_i = irow* ap.get_col_size()/2 + icol;
285+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
286+
matrix_out[index_i] = matrix_in[index_j];
287+
}
291288
}
292-
data_full += ap.get_col_size();
293289
}
294290
}
291+
hamilt::transferParallels2Serials( *(this->dm2d_tmp), this->DMRGint[is]);
295292
}
293+
#else
294+
//this->DMRGint_full = DM2D[0];
295+
#endif
296296
}
297297
ModuleBase::timer::tick("Gint", "transfer_DMR");
298298
}

source/source_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -163,44 +163,46 @@ void transfer_dm_2d_to_gint(
163163
} else // NSPIN=4 case
164164
{
165165
#ifdef __MPI
166-
const int npol = 2;
167-
HContainer<T> dm_full = gint_info.get_hr<T>(npol);
168-
hamilt::transferParallels2Serials(*dm[0], &dm_full);
169-
#else
170-
HContainer<T>& dm_full = *(dm[0]);
171-
#endif
172-
std::vector<T*> tmp_pointer(4, nullptr);
173-
for (int iap = 0; iap < dm_full.size_atom_pairs(); iap++)
174-
{
175-
auto& ap = dm_full.get_atom_pair(iap);
176-
const int iat1 = ap.get_atom_i();
177-
const int iat2 = ap.get_atom_j();
178-
for (int ir = 0; ir < ap.get_R_size(); ir++)
179-
{
180-
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
181-
for (int is = 0; is < 4; is++)
182-
{
183-
tmp_pointer[is] =
184-
dm_gint[is].find_matrix(iat1, iat2, r_index)->get_pointer();
185-
}
186-
T* data_full = ap.get_pointer(ir);
187-
for (int irow = 0; irow < ap.get_row_size(); irow += 2)
188-
{
189-
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
190-
{
191-
*(tmp_pointer[0])++ = data_full[icol];
192-
*(tmp_pointer[1])++ = data_full[icol + 1];
193-
}
194-
data_full += ap.get_col_size();
195-
for (int icol = 0; icol < ap.get_col_size(); icol += 2)
196-
{
197-
*(tmp_pointer[2])++ = data_full[icol];
198-
*(tmp_pointer[3])++ = data_full[icol + 1];
166+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
167+
const int row_set[4] = {0, 0, 1, 1};
168+
const int col_set[4] = {0, 1, 0, 1};
169+
int mg = dm[0]->get_paraV()->get_global_row_size()/2;
170+
int ng = dm[0]->get_paraV()->get_global_col_size()/2;
171+
int nb = dm[0]->get_paraV()->get_block_size()/2;
172+
int blacs_ctxt = dm[0]->get_paraV()->blacs_ctxt;
173+
const UnitCell* ucell = gint_info.get_ucell();
174+
std::vector<int> iat2iwt(ucell->nat);
175+
for (int iat = 0; iat < ucell->nat; iat++) {
176+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
177+
}
178+
Parallel_Orbitals *pv = new Parallel_Orbitals();
179+
pv->set(mg, ng, nb, blacs_ctxt);
180+
pv->set_atomic_trace(iat2iwt.data(), ucell->nat, mg);
181+
auto ijr_info = dm[0]->get_ijr_info();
182+
HContainer<T>* dm2d_tmp = new hamilt::HContainer<T>(pv, nullptr, &ijr_info);
183+
for (int is = 0; is < 4; is++){
184+
for (int iap = 0; iap < dm[0]->size_atom_pairs(); ++iap) {
185+
auto& ap = dm[0]->get_atom_pair(iap);
186+
int iat1 = ap.get_atom_i();
187+
int iat2 = ap.get_atom_j();
188+
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
189+
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
190+
T* matrix_out = dm2d_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
191+
T* matrix_in = ap.get_pointer(ir);
192+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
193+
for (int icol = 0; icol < ap.get_col_size()/2; icol ++) {
194+
int index_i = irow* ap.get_col_size()/2 + icol;
195+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
196+
matrix_out[index_i] = matrix_in[index_j];
197+
}
199198
}
200-
data_full += ap.get_col_size();
201199
}
202200
}
201+
hamilt::transferParallels2Serials( *dm2d_tmp, &dm_gint[is]);
203202
}
203+
#else
204+
//HContainer<T>& dm_full = *(dm[0]);
205+
#endif
204206
}
205207
ModuleBase::timer::tick("Gint", "transfer_dm_2d_to_gint");
206208
}

source/source_lcao/module_gint/temp_gint/gint_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class GintInfo
3838
const std::vector<int>& get_trace_lo() const{ return trace_lo_; }
3939
int get_lgd() const { return lgd_; }
4040
int get_nat() const { return ucell_->nat; } // return the number of atoms in the unitcell
41+
const UnitCell* get_ucell() const { return ucell_; }
4142
int get_local_mgrid_num() const { return localcell_info_->get_mgrids_num(); }
4243
double get_mgrid_volume() const { return meshgrid_info_->get_volume(); }
4344

source/source_lcao/module_lr/utils/gint_move.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ Gint& Gint::operator=(Gint&& rhs)
6060
this->pvdpRz_reduced = std::move(rhs.pvdpRz_reduced);
6161
this->DMRGint = std::move(rhs.DMRGint);
6262
this->hRGint_tmp = std::move(rhs.hRGint_tmp);
63-
this->DMRGint_full = rhs.DMRGint_full;
64-
rhs.DMRGint_full = nullptr;
63+
this->dm2d_tmp = rhs.dm2d_tmp;
64+
rhs.dm2d_tmp = nullptr;
6565

6666
return *this;
6767
}

0 commit comments

Comments
 (0)