Skip to content

Commit b5a0f8f

Browse files
zgn-26714dyzheng
authored andcommitted
delete hr_gint_full_ to reduce memory usage
1 parent 7d4fe5a commit b5a0f8f

File tree

10 files changed

+143
-88
lines changed

10 files changed

+143
-88
lines changed

source/source_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 133 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -47,95 +47,163 @@ void compose_hr_gint(HContainer<double>& hr_gint)
4747
ModuleBase::timer::tick("Gint", "compose_hr_gint");
4848
}
4949

50-
void compose_hr_gint(const std::vector<HContainer<double>>& hr_gint_part,
51-
HContainer<std::complex<double>>& hr_gint_full)
50+
template <typename T>
51+
void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR)
5252
{
53-
ModuleBase::TITLE("Gint", "compose_hr_gint");
54-
ModuleBase::timer::tick("Gint", "compose_hr_gint");
55-
for (int iap = 0; iap < hr_gint_full.size_atom_pairs(); iap++)
53+
ModuleBase::TITLE("Gint", "transfer_hr_gint_to_hR");
54+
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR");
55+
#ifdef __MPI
56+
int size = 0;
57+
MPI_Comm_size(MPI_COMM_WORLD, &size);
58+
if (size == 1)
59+
{
60+
hR.add(hr_gint);
61+
}
62+
else
5663
{
57-
auto* ap = &(hr_gint_full.get_atom_pair(iap));
58-
const int iat1 = ap->get_atom_i();
59-
const int iat2 = ap->get_atom_j();
60-
if (iat1 <= iat2)
64+
hamilt::transferSerials2Parallels(hr_gint, &hR);
65+
}
66+
#else
67+
hR.add(hr_gint);
68+
#endif
69+
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR");
70+
}
71+
72+
//hRgint_tmp to hR
73+
void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
74+
HContainer<std::complex<double>>& hR,
75+
const GintInfo& gint_info)
76+
{
77+
ModuleBase::TITLE("Gint", "transfer_hr_gint_to_hR_nspin4");
78+
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR_nspin4");
79+
#ifdef __MPI
80+
int mg = hR.get_paraV()->get_global_row_size()/2;
81+
int ng = hR.get_paraV()->get_global_col_size()/2;
82+
int nb = hR.get_paraV()->get_block_size()/2;
83+
int blacs_ctxt = hR.get_paraV()->blacs_ctxt;
84+
const UnitCell* ucell = gint_info.get_ucell();
85+
int *iat2iwt = new int[ucell->nat];
86+
for (int iat = 0; iat < ucell->nat; iat++) {
87+
iat2iwt[iat] = ucell->get_iat2iwt()[iat]/2;
88+
}
89+
Parallel_Orbitals *pv = new Parallel_Orbitals();
90+
pv->set(mg, ng, nb, blacs_ctxt);
91+
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
92+
auto ijr_info = hR.get_ijr_info();
93+
94+
hamilt::HContainer<double>* hR_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
95+
for (int is = 0; is < 4; is++){
96+
hR_tmp->set_zero();
97+
//std::cout<<"is: "<<is<<std::endl;
98+
hamilt::transferSerials2Parallels( hRGint_tmp[is], hR_tmp);
99+
for (int iap = 0; iap < hR.size_atom_pairs(); iap++)
61100
{
62-
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
63-
hamilt::AtomPair<std::complex<double>>* lower_ap = hr_gint_full.find_pair(iat2, iat1);
64-
const hamilt::AtomPair<double>* ap_nspin_0 = hr_gint_part[0].find_pair(iat1, iat2);
65-
const hamilt::AtomPair<double>* ap_nspin_3 = hr_gint_part[3].find_pair(iat1, iat2);
66-
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
101+
//std::cout<<"iap: "<<iap<<std::endl;
102+
auto* ap = &hR.get_atom_pair(iap);
103+
const int iat1 = ap->get_atom_i();
104+
const int iat2 = ap->get_atom_j();
105+
const hamilt::AtomPair<double>* ap_nspin = nullptr;
106+
if (iat1 <= iat2)
67107
{
68-
const auto R_index = upper_ap->get_R_index(ir);
69-
auto upper_mat = upper_ap->find_matrix(R_index);
70-
auto mat_nspin_0 = ap_nspin_0->find_matrix(R_index);
71-
auto mat_nspin_3 = ap_nspin_3->find_matrix(R_index);
72-
73-
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
74-
for (int irow = 0; irow < mat_nspin_0->get_row_size(); ++irow)
108+
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
109+
hamilt::AtomPair<std::complex<double>>* lower_ap = hR.find_pair(iat2, iat1);
110+
switch (is)
75111
{
76-
for (int icol = 0; icol < mat_nspin_0->get_col_size(); ++icol)
77-
{
78-
upper_mat->get_value(2*irow, 2*icol) = mat_nspin_0->get_value(irow, icol) + mat_nspin_3->get_value(irow, icol);
79-
upper_mat->get_value(2*irow+1, 2*icol+1) = mat_nspin_0->get_value(irow, icol) - mat_nspin_3->get_value(irow, icol);
80-
}
112+
case 0:
113+
ap_nspin = hR_tmp->find_pair(iat1, iat2);
114+
break;
115+
case 3:
116+
ap_nspin = hR_tmp->find_pair(iat1, iat2);
117+
break;
81118
}
119+
if(ap_nspin == nullptr) break;
120+
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
121+
{
122+
const auto R_index = upper_ap->get_R_index(ir);
123+
auto upper_mat = upper_ap->find_matrix(R_index);
124+
auto mat_nspin = ap_nspin->find_matrix(R_index);
82125

83-
if (PARAM.globalv.domag)
84-
{
85-
const hamilt::AtomPair<double>* ap_nspin_1 = hr_gint_part[1].find_pair(iat1, iat2);
86-
const hamilt::AtomPair<double>* ap_nspin_2 = hr_gint_part[2].find_pair(iat1, iat2);
87-
const auto mat_nspin_1 = ap_nspin_1->find_matrix(R_index);
88-
const auto mat_nspin_2 = ap_nspin_2->find_matrix(R_index);
89-
for (int irow = 0; irow < mat_nspin_1->get_row_size(); ++irow)
126+
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
127+
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
90128
{
91-
for (int icol = 0; icol < mat_nspin_1->get_col_size(); ++icol)
129+
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
92130
{
93-
upper_mat->get_value(2*irow, 2*icol+1) = mat_nspin_1->get_value(irow, icol) + std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol);
94-
upper_mat->get_value(2*irow+1, 2*icol) = mat_nspin_1->get_value(irow, icol) - std::complex<double>(0.0, 1.0) * mat_nspin_2->get_value(irow, icol);
131+
switch (is)
132+
{
133+
case 0:
134+
upper_mat->get_value(2*irow, 2*icol) = mat_nspin->get_value(irow, icol);
135+
upper_mat->get_value(2*irow+1, 2*icol+1) = mat_nspin->get_value(irow, icol);
136+
break;
137+
case 3:
138+
upper_mat->get_value(2*irow, 2*icol) += mat_nspin->get_value(irow, icol);
139+
upper_mat->get_value(2*irow+1, 2*icol+1) -= mat_nspin->get_value(irow, icol);
140+
break;
141+
}
95142
}
96143
}
97-
}
98144

99-
// fill the lower triangle matrix
100-
if (iat1 < iat2)
101-
{
102-
auto lower_mat = lower_ap->find_matrix(-R_index);
103-
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
145+
if (PARAM.globalv.domag)
104146
{
105-
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
147+
const hamilt::AtomPair<double>* ap_nspin = nullptr;
148+
switch (is)
106149
{
107-
lower_mat->get_value(icol, irow) = conj(upper_mat->get_value(irow, icol));
150+
case 1:
151+
ap_nspin = hR_tmp->find_pair(iat1, iat2);
152+
break;
153+
case 2:
154+
ap_nspin = hR_tmp->find_pair(iat1, iat2);
155+
break;
156+
}
157+
const auto mat_nspin = ap_nspin->find_matrix(R_index);
158+
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
159+
{
160+
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
161+
{
162+
switch(is)
163+
{
164+
case 1:
165+
upper_mat->get_value(2*irow, 2*icol+1) = mat_nspin->get_value(irow, icol);
166+
upper_mat->get_value(2*irow+1, 2*icol) = mat_nspin->get_value(irow, icol);
167+
break;
168+
case 2:
169+
upper_mat->get_value(2*irow, 2*icol+1) += std::complex<double>(0.0, 1.0) * mat_nspin->get_value(irow, icol);
170+
upper_mat->get_value(2*irow+1, 2*icol) -= std::complex<double>(0.0, 1.0) * mat_nspin->get_value(irow, icol);
171+
break;
172+
}
173+
}
174+
}
175+
}
176+
177+
// fill the lower triangle matrix
178+
if(is == 3){
179+
if (iat1 < iat2)
180+
{
181+
auto lower_mat = lower_ap->find_matrix(-R_index);
182+
for (int irow = 0; irow < upper_mat->get_row_size(); ++irow)
183+
{
184+
for (int icol = 0; icol < upper_mat->get_col_size(); ++icol)
185+
{
186+
lower_mat->get_value(icol, irow) = conj(upper_mat->get_value(irow, icol));
187+
}
188+
}
108189
}
109190
}
110191
}
111192
}
112193
}
194+
113195
}
114-
ModuleBase::timer::tick("Gint", "compose_hr_gint");
115-
}
116-
117-
template <typename T>
118-
void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR)
119-
{
120-
ModuleBase::TITLE("Gint", "transfer_hr_gint_to_hR");
121-
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR");
122-
#ifdef __MPI
123-
int size = 0;
124-
MPI_Comm_size(MPI_COMM_WORLD, &size);
125-
if (size == 1)
126-
{
127-
hR.add(hr_gint);
128-
}
129-
else
130-
{
131-
hamilt::transferSerials2Parallels(hr_gint, &hR);
132-
}
196+
delete[] iat2iwt;
197+
delete pv;
198+
delete hR_tmp;
133199
#else
134-
hR.add(hr_gint);
200+
135201
#endif
136-
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR");
202+
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR_nspin4");
203+
return;
137204
}
138205

206+
139207
// gint_info should not have been a parameter, but it was added to initialize dm_gint_full
140208
// In the future, we might try to remove the gint_info parameter
141209
template<typename T>

source/source_lcao/module_gint/temp_gint/gint_common.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ namespace ModuleGint
66
{
77
// fill the lower triangle matrix with the upper triangle matrix
88
void compose_hr_gint(HContainer<double>& hr_gint);
9-
// for nspin=4 case
10-
void compose_hr_gint(const std::vector<HContainer<double>>& hr_gint_part,
11-
HContainer<std::complex<double>>& hr_gint_full);
9+
1210

1311
template <typename T>
1412
void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR);
13+
// for nspin=4 case
14+
void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
15+
HContainer<std::complex<double>>& hR,
16+
const GintInfo& gint_info);
1517

1618
template<typename T>
1719
void transfer_dm_2d_to_gint(

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ void Gint_vl_metagga_nspin4::cal_gint()
1414
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1515
init_hr_gint_();
1616
cal_hr_gint_();
17-
compose_hr_gint(hr_gint_part_, hr_gint_full_);
18-
transfer_hr_gint_to_hR(hr_gint_full_, *hR_);
17+
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
1918
ModuleBase::timer::tick("Gint", "cal_gint_vl");
2019
}
2120

@@ -26,8 +25,6 @@ void Gint_vl_metagga_nspin4::init_hr_gint_()
2625
{
2726
hr_gint_part_[i] = gint_info_->get_hr<double>();
2827
}
29-
const int npol = 2;
30-
hr_gint_full_ = gint_info_->get_hr<std::complex<double>>(npol);
3128
}
3229

3330
void Gint_vl_metagga_nspin4::cal_hr_gint_()

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ class Gint_vl_metagga_nspin4 : public Gint
3737
const int nspin_ = 4;
3838

3939
std::vector<HContainer<double>> hr_gint_part_;
40-
HContainer<std::complex<double>> hr_gint_full_;
4140
};
4241

4342
}

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4_gpu.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ void Gint_vl_metagga_nspin4_gpu::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
compose_hr_gint(hr_gint_part_, hr_gint_full_);
17-
transfer_hr_gint_to_hR(hr_gint_full_, *hR_);
16+
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
1817
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1918
}
2019

@@ -25,8 +24,6 @@ void Gint_vl_metagga_nspin4_gpu::init_hr_gint_()
2524
{
2625
hr_gint_part_[i] = gint_info_->get_hr<double>();
2726
}
28-
const int npol = 2;
29-
hr_gint_full_ = gint_info_->get_hr<std::complex<double>>(npol);
3027
}
3128

3229
void Gint_vl_metagga_nspin4_gpu::transfer_cpu_to_gpu_()

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4_gpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Gint_vl_metagga_nspin4_gpu : public Gint
4242
const int nspin_ = 4;
4343

4444
std::vector<HContainer<double>> hr_gint_part_;
45-
HContainer<std::complex<double>> hr_gint_full_;
45+
//HContainer<std::complex<double>> hr_gint_full_;
4646

4747
std::vector<CudaMemWrapper<double>> vr_eff_d_;
4848
std::vector<CudaMemWrapper<double>> vofk_d_;

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ void Gint_vl_nspin4::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
compose_hr_gint(hr_gint_part_, hr_gint_full_);
17-
transfer_hr_gint_to_hR(hr_gint_full_, *hR_);
16+
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
1817
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1918
}
2019

@@ -25,8 +24,6 @@ void Gint_vl_nspin4::init_hr_gint_()
2524
{
2625
hr_gint_part_[i] = gint_info_->get_hr<double>();
2726
}
28-
const int npol = 2;
29-
hr_gint_full_ = gint_info_->get_hr<std::complex<double>>(npol);
3027
}
3128

3229
void Gint_vl_nspin4::cal_hr_gint_()

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class Gint_vl_nspin4 : public Gint
3939
const int nspin_ = 4;
4040

4141
std::vector<HContainer<double>> hr_gint_part_;
42-
HContainer<std::complex<double>> hr_gint_full_;
4342
};
4443

4544
} // namespace ModuleGint

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4_gpu.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ void Gint_vl_nspin4_gpu::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
compose_hr_gint(hr_gint_part_, hr_gint_full_);
17-
transfer_hr_gint_to_hR(hr_gint_full_, *hR_);
16+
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
1817
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1918
}
2019

@@ -25,8 +24,6 @@ void Gint_vl_nspin4_gpu::init_hr_gint_()
2524
{
2625
hr_gint_part_[i] = gint_info_->get_hr<double>();
2726
}
28-
const int npol = 2;
29-
hr_gint_full_ = gint_info_->get_hr<std::complex<double>>(npol);
3027
}
3128

3229
void Gint_vl_nspin4_gpu::transfer_cpu_to_gpu_()

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4_gpu.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class Gint_vl_nspin4_gpu : public Gint
4444
const int nspin_ = 4;
4545

4646
std::vector<HContainer<double>> hr_gint_part_;
47-
HContainer<std::complex<double>> hr_gint_full_;
4847

4948
std::vector<CudaMemWrapper<double>> vr_eff_d_;
5049
std::vector<CudaMemWrapper<double>> hr_gint_part_d_;

0 commit comments

Comments
 (0)