Skip to content

Commit 0f13efe

Browse files
zgn-26714dyzheng
andauthored
Removed the temporary variable hRGintCd when transitioning from 2D block parallelism to serial in Hcontainer. (#6488)
* Fixed the bug in memory statistics * delete tem Hcontainer to reduce memory usage * delete tem hRGintCd to reduce memory usage * fix parallel bug * Simplify the computational code * fix bug * simplify the compute code * Fix:error of lower_mat filling * Remove unnecessary comments, optimize calculation code * fix bug --------- Co-authored-by: dyzheng <[email protected]>
1 parent f2fb9af commit 0f13efe

File tree

4 files changed

+110
-110
lines changed

4 files changed

+110
-110
lines changed

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Gint::~Gint() {
2424

2525
delete this->hRGint;
26-
delete this->hRGintCd;
26+
delete this->hR_tmp;
2727
// in gamma_only case, DMRGint.size()=0,
2828
// in multi-k case, DMRGint.size()=nspin
2929
for (int is = 0; is < this->DMRGint.size(); is++) {
@@ -155,11 +155,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
155155
this->hRGint = new hamilt::HContainer<double>(ucell_in.nat);
156156
} else {
157157
npol = 2;
158-
if (this->hRGintCd != nullptr) {
159-
delete this->hRGintCd;
158+
if (this->hR_tmp != nullptr) {
159+
delete this->hR_tmp;
160160
}
161-
this->hRGintCd
162-
= new hamilt::HContainer<std::complex<double>>(ucell_in.nat);
163161
for (int is = 0; is < nspin; is++) {
164162
if (this->DMRGint[is] != nullptr) {
165163
delete this->DMRGint[is];
@@ -196,10 +194,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
196194
this->DMRGint[0]->get_memory_size()
197195
* this->DMRGint.size());
198196
} else {
199-
this->hRGintCd->insert_ijrs(this->gridt->get_ijr_info(), ucell_in, npol);
200-
this->hRGintCd->allocate(nullptr, true);
201-
ModuleBase::Memory::record("Gint::hRGintCd",
202-
this->hRGintCd->get_memory_size());
203197
for(int is = 0; is < nspin; is++) {
204198
this->hRGint_tmp[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
205199
this->DMRGint[is]->insert_ijrs(this->gridt->get_ijr_info(), ucell_in);
@@ -254,19 +248,24 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
254248
} else // NSPIN=4 case
255249
{
256250
#ifdef __MPI
251+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
252+
const int row_set[4] = {0, 0, 1, 1};
253+
const int col_set[4] = {0, 1, 0, 1};
257254
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
258255
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
259256
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
260257
int blacs_ctxt = DM2D[0]->get_paraV()->blacs_ctxt;
261-
int *iat2iwt = new int[ucell->nat];
258+
259+
std::vector<int> iat2iwt(ucell->nat);
262260
for (int iat = 0; iat < ucell->nat; iat++) {
263261
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
264262
}
265263
Parallel_Orbitals *pv = new Parallel_Orbitals();
266264
pv->set(mg, ng, nb, blacs_ctxt);
267-
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
265+
pv->set_atomic_trace(iat2iwt.data(), ucell->nat, mg);
268266
auto ijr_info = DM2D[0]->get_ijr_info();
269267
this-> DM2D_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
268+
this-> DM2D_tmp->set_zero();
270269
ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
271270
for (int is = 0; is < 4; is++){
272271
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
@@ -275,45 +274,19 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
275274
int iat2 = ap.get_atom_j();
276275
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
277276
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
278-
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
279-
double* data_full = ap.get_pointer(ir);
280-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
281-
switch (is) {//todo: It can be written more compactly
282-
case 0:
283-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
284-
*(tmp_pointer)++ = data_full[icol];
285-
}
286-
data_full += ap.get_col_size() * 2;
287-
break;
288-
case 1:
289-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
290-
*(tmp_pointer)++ = data_full[icol + 1];
291-
}
292-
data_full += ap.get_col_size() * 2;
293-
break;
294-
case 2:
295-
data_full += ap.get_col_size();
296-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
297-
*(tmp_pointer)++ = data_full[icol];
298-
}
299-
data_full += ap.get_col_size();
300-
break;
301-
case 3:
302-
data_full += ap.get_col_size();
303-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
304-
*(tmp_pointer)++ = data_full[icol + 1];
305-
}
306-
data_full += ap.get_col_size();
307-
break;
277+
double* matrix_out = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
278+
double* matrix_in = ap.get_pointer(ir);
279+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
280+
for (int icol = 0; icol < ap.get_col_size()/2; icol ++) {
281+
int index_i = irow* ap.get_col_size()/2 + icol;
282+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
283+
matrix_out[index_i] = matrix_in[index_j];
308284
}
309285
}
310286
}
311287
}
312288
hamilt::transferParallels2Serials( *(this->DM2D_tmp), this->DMRGint[is]);
313289
}
314-
// delete iat2iwt;
315-
// delete pv;
316-
// delete this-> DM2D_tmp;
317290
#else
318291
//this->DMRGint_full = DM2D[0];
319292
#endif

source/module_hamilt_lcao/module_gint/gint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class Gint {
258258
std::vector<hamilt::HContainer<double>*> hRGint_tmp;
259259

260260
//! stores Hamiltonian in sparse format
261-
hamilt::HContainer<std::complex<double>>* hRGintCd = nullptr;
261+
hamilt::HContainer<std::complex<double>>* hR_tmp = nullptr;
262262

263263
//! stores DMR in sparse format
264264
std::vector<hamilt::HContainer<double>*> DMRGint;

source/module_hamilt_lcao/module_gint/gint_k_pvpr.cpp

Lines changed: 90 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -78,86 +78,113 @@ void Gint_k::transfer_pvpR(hamilt::HContainer<std::complex<double>>* hR,
7878
ModuleBase::TITLE("Gint_k", "transfer_pvpR");
7979
ModuleBase::timer::tick("Gint_k", "transfer_pvpR");
8080

81-
this->hRGintCd->set_zero();
82-
83-
for (int iap = 0; iap < this->hRGintCd->size_atom_pairs(); iap++)
84-
{
85-
auto* ap = &this->hRGintCd->get_atom_pair(iap);
86-
const int iat1 = ap->get_atom_i();
87-
const int iat2 = ap->get_atom_j();
88-
if (iat1 <= iat2)
89-
{
90-
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
91-
hamilt::AtomPair<std::complex<double>>* lower_ap = this->hRGintCd->find_pair(iat2, iat1);
92-
const hamilt::AtomPair<double>* ap_nspin_0 = this->hRGint_tmp[0]->find_pair(iat1, iat2);
93-
const hamilt::AtomPair<double>* ap_nspin_3 = this->hRGint_tmp[3]->find_pair(iat1, iat2);
94-
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
95-
{
96-
const auto R_index = upper_ap->get_R_index(ir);
97-
auto upper_mat = upper_ap->find_matrix(R_index);
98-
auto mat_nspin_0 = ap_nspin_0->find_matrix(R_index);
99-
auto mat_nspin_3 = ap_nspin_3->find_matrix(R_index);
81+
int mg = hR->get_paraV()->get_global_row_size()/2;
82+
int ng = hR->get_paraV()->get_global_col_size()/2;
83+
int nb = hR->get_paraV()->get_block_size()/2;
84+
#ifdef __MPI
85+
int blacs_ctxt = hR->get_paraV()->blacs_ctxt;
86+
std::vector<int> iat2iwt(ucell_in->nat);
87+
for (int iat = 0; iat < ucell_in->nat; iat++) {
88+
iat2iwt[iat] = ucell_in->get_iat2iwt()[iat]/2;
89+
}
90+
Parallel_Orbitals *pv = new Parallel_Orbitals();
91+
pv->set(mg, ng, nb, blacs_ctxt);
92+
pv->set_atomic_trace(iat2iwt.data(), ucell_in->nat, mg);
93+
auto ijr_info = hR->get_ijr_info();
10094

101-
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
102-
for (int irow = 0; irow < mat_nspin_0->get_row_size(); ++irow)
103-
{
104-
for (int icol = 0; icol < mat_nspin_0->get_col_size(); ++icol)
95+
this->hR_tmp = new hamilt::HContainer<std::complex<double>>(pv, nullptr, &ijr_info);
96+
ModuleBase::Memory::record("Gint::hRGintCd", this->hR_tmp->get_memory_size());
97+
98+
//select hRGint_tmp
99+
std::vector<int> first = {0, 1, 1, 0};
100+
std::vector<int> second= {3, 2, 2, 3};
101+
//select position in the big matrix
102+
std::vector<int> row_set = {0, 0, 1, 1};
103+
std::vector<int> col_set = {0, 1, 0, 1};
104+
//construct complex matrix
105+
std::vector<int> clx_i = {1, 0, 0, -1};
106+
std::vector<int> clx_j = {0, 1, -1, 0};
107+
for (int is = 0; is < 4; is++){
108+
if(!PARAM.globalv.domag && (is==1 || is==2)) continue;
109+
this->hR_tmp->set_zero();
110+
hamilt::HContainer<std::complex<double>>* hRGint_tmpCd = new hamilt::HContainer<std::complex<double>>(this->ucell->nat);
111+
hRGint_tmpCd->insert_ijrs(this->gridt->get_ijr_info(), *(this->ucell));
112+
hRGint_tmpCd->allocate(nullptr, true);
113+
hRGint_tmpCd->set_zero();
114+
for (int iap = 0; iap < hRGint_tmpCd->size_atom_pairs(); iap++)
115+
{
116+
auto* ap = &hRGint_tmpCd->get_atom_pair(iap);
117+
const int iat1 = ap->get_atom_i();
118+
const int iat2 = ap->get_atom_j();
119+
if (iat1 <= iat2)
120+
{
121+
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
122+
hamilt::AtomPair<std::complex<double>>* lower_ap = hRGint_tmpCd->find_pair(iat2, iat1);
123+
const hamilt::AtomPair<double>* ap_nspin1 = this->hRGint_tmp[first[is]] ->find_pair(iat1, iat2);
124+
const hamilt::AtomPair<double>* ap_nspin2 = this->hRGint_tmp[second[is]] ->find_pair(iat1, iat2);
125+
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
126+
{
127+
const auto R_index = upper_ap->get_R_index(ir);
128+
auto upper_mat = upper_ap->find_matrix(R_index);
129+
auto mat_nspin1 = ap_nspin1->find_matrix(R_index);
130+
auto mat_nspin2 = ap_nspin2->find_matrix(R_index);
131+
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
132+
for (int irow = 0; irow < mat_nspin1->get_row_size(); ++irow)
105133
{
106-
upper_mat->get_value(2*irow, 2*icol) = mat_nspin_0->get_value(irow, icol) + mat_nspin_3->get_value(irow, icol);
107-
upper_mat->get_value(2*irow+1, 2*icol+1) = mat_nspin_0->get_value(irow, icol) - mat_nspin_3->get_value(irow, icol);
134+
for (int icol = 0; icol < mat_nspin1->get_col_size(); ++icol)
135+
{
136+
upper_mat->get_value(irow, icol) = mat_nspin1->get_value(irow, icol)
137+
+ std::complex<double>(clx_i[is], clx_j[is]) * mat_nspin2->get_value(irow, icol);
138+
}
108139
}
109-
}
110-
111-
if (PARAM.globalv.domag)
112-
{
113-
const hamilt::AtomPair<double>* ap_nspin_1 = this->hRGint_tmp[1]->find_pair(iat1, iat2);
114-
const hamilt::AtomPair<double>* ap_nspin_2 = this->hRGint_tmp[2]->find_pair(iat1, iat2);
115-
const auto mat_nspin_1 = ap_nspin_1->find_matrix(R_index);
116-
const auto mat_nspin_2 = ap_nspin_2->find_matrix(R_index);
117-
for (int irow = 0; irow < mat_nspin_1->get_row_size(); ++irow)
140+
//fill the lower triangle matrix
141+
//When is=0 or 3, the real part does not need conjugation;
142+
//when is=1 or 2, the small matrix is not Hermitian, so conjugation is not needed
143+
if (iat1 < iat2)
118144
{
119-
for (int icol = 0; icol < mat_nspin_1->get_col_size(); ++icol)
145+
auto lower_mat = lower_ap->find_matrix(-R_index);
146+
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
120147
{
121-
upper_mat->get_value(2*irow, 2*icol+1) = mat_nspin_1->get_value(irow, icol) + std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol);
122-
upper_mat->get_value(2*irow+1, 2*icol) = mat_nspin_1->get_value(irow, icol) - std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol);
148+
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
149+
{
150+
lower_mat->get_value(icol, irow) = upper_mat->get_value(irow, icol);
151+
}
123152
}
124153
}
125-
}
126154

127-
// fill the lower triangle matrix
128-
if (iat1 < iat2)
155+
}
156+
}
157+
}
158+
// transfer hRGint_tmpCd to parallel hR_tmp
159+
hamilt::transferSerials2Parallels( *hRGint_tmpCd, this->hR_tmp);
160+
// merge hR_tmp to hR
161+
for (int iap = 0; iap < hR->size_atom_pairs(); iap++)
162+
{
163+
auto* ap = &hR->get_atom_pair(iap);
164+
const int iat1 = ap->get_atom_i();
165+
const int iat2 = ap->get_atom_j();
166+
auto* ap_nspin = this->hR_tmp ->find_pair(iat1, iat2);
167+
for (int ir = 0; ir < ap->get_R_size(); ir++)
168+
{
169+
const auto R_index = ap->get_R_index(ir);
170+
auto upper_mat = ap->find_matrix(R_index);
171+
auto mat_nspin = ap_nspin->find_matrix(R_index);
172+
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
173+
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
129174
{
130-
auto lower_mat = lower_ap->find_matrix(-R_index);
131-
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
175+
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
132176
{
133-
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
134-
{
135-
lower_mat->get_value(icol, irow) = conj(upper_mat->get_value(irow, icol));
136-
}
177+
upper_mat->get_value(2*irow+row_set[is], 2*icol+col_set[is]) =
178+
mat_nspin->get_value(irow, icol);
137179
}
138180
}
139181
}
140182
}
141-
}
142-
143-
// ===================================
144-
// transfer HR from Gint to Veff<OperatorLCAO<std::complex<double>, std::complex<double>>>
145-
// ===================================
146-
#ifdef __MPI
147-
int size;
148-
MPI_Comm_size(MPI_COMM_WORLD, &size);
149-
if (size == 1)
150-
{
151-
hR->add(*this->hRGintCd);
152-
}
153-
else
154-
{
155-
hamilt::transferSerials2Parallels<std::complex<double>>(*this->hRGintCd, hR);
183+
delete hRGint_tmpCd;
156184
}
157185
#else
158-
hR->add(*this->hRGintCd);
159-
#endif
160186

187+
#endif
161188
ModuleBase::timer::tick("Gint_k", "transfer_pvpR");
162189
return;
163190
}

source/module_lr/utils/gint_move.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ Gint& Gint::operator=(Gint&& rhs)
4545
// move hR after refactor
4646
this->hRGint = rhs.hRGint;
4747
rhs.hRGint = nullptr;
48-
this->hRGintCd = rhs.hRGintCd;
49-
rhs.hRGintCd = nullptr;
48+
this->hR_tmp = rhs.hR_tmp;
49+
rhs.hR_tmp = nullptr;
5050
for (int i = 0; i < this->DMRGint.size(); i++)
5151
{
5252
delete this->DMRGint[i];

0 commit comments

Comments
 (0)