Skip to content

Commit a9cc5c6

Browse files
committed
fix parallel bug and Simplify the computational code
1 parent b5a0f8f commit a9cc5c6

File tree

6 files changed

+76
-111
lines changed

6 files changed

+76
-111
lines changed

source/source_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 71 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,18 @@ void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR)
7070
}
7171

7272
//hRgint_tmp to hR
73-
void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
73+
void merge_hR_n4(std::vector<HContainer<double>>& hRGint_tmp,
7474
HContainer<std::complex<double>>& hR,
7575
const GintInfo& gint_info)
7676
{
77-
ModuleBase::TITLE("Gint", "transfer_hr_gint_to_hR_nspin4");
78-
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR_nspin4");
77+
ModuleBase::TITLE("Gint", "merge_hR_n4");
78+
ModuleBase::timer::tick("Gint", "merge_hR_n4");
7979
#ifdef __MPI
8080
int mg = hR.get_paraV()->get_global_row_size()/2;
8181
int ng = hR.get_paraV()->get_global_col_size()/2;
8282
int nb = hR.get_paraV()->get_block_size()/2;
8383
int blacs_ctxt = hR.get_paraV()->blacs_ctxt;
84+
8485
const UnitCell* ucell = gint_info.get_ucell();
8586
int *iat2iwt = new int[ucell->nat];
8687
for (int iat = 0; iat < ucell->nat; iat++) {
@@ -91,91 +92,49 @@ void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
9192
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
9293
auto ijr_info = hR.get_ijr_info();
9394

94-
hamilt::HContainer<double>* hR_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
95+
auto* hR_tmp = new hamilt::HContainer<std::complex<double>>(pv, nullptr, &ijr_info);
96+
97+
std::vector<int> first = {0, 1, 1, 0};
98+
std::vector<int> second= {3, 2, 2, 3};
99+
std::vector<int> row_set = {0, 0, 1, 1};
100+
std::vector<int> col_set = {0, 1, 0, 1};
101+
std::vector<int> clx_i = {1, 0, 0, -1};
102+
std::vector<int> clx_j = {0, 1, -1, 0};
95103
for (int is = 0; is < 4; is++){
96-
hR_tmp->set_zero();
97-
//std::cout<<"is: "<<is<<std::endl;
98-
hamilt::transferSerials2Parallels( hRGint_tmp[is], hR_tmp);
99-
for (int iap = 0; iap < hR.size_atom_pairs(); iap++)
104+
hamilt::HContainer<std::complex<double>>* hRGint_tmpCd = new hamilt::HContainer<std::complex<double>>(ucell->nat);
105+
ijr_info = hRGint_tmp[0].get_ijr_info();
106+
hRGint_tmpCd->insert_ijrs(&ijr_info, *(ucell));
107+
hRGint_tmpCd->allocate(nullptr, true);
108+
hRGint_tmpCd->set_zero();
109+
for (int iap = 0; iap < hRGint_tmpCd->size_atom_pairs(); iap++)
100110
{
101111
//std::cout<<"iap: "<<iap<<std::endl;
102-
auto* ap = &hR.get_atom_pair(iap);
112+
auto* ap = &hRGint_tmpCd->get_atom_pair(iap);
103113
const int iat1 = ap->get_atom_i();
104114
const int iat2 = ap->get_atom_j();
105-
const hamilt::AtomPair<double>* ap_nspin = nullptr;
106115
if (iat1 <= iat2)
107116
{
108117
hamilt::AtomPair<std::complex<double>>* upper_ap = ap;
109-
hamilt::AtomPair<std::complex<double>>* lower_ap = hR.find_pair(iat2, iat1);
110-
switch (is)
111-
{
112-
case 0:
113-
ap_nspin = hR_tmp->find_pair(iat1, iat2);
114-
break;
115-
case 3:
116-
ap_nspin = hR_tmp->find_pair(iat1, iat2);
117-
break;
118-
}
119-
if(ap_nspin == nullptr) break;
118+
hamilt::AtomPair<std::complex<double>>* lower_ap = hRGint_tmpCd->find_pair(iat2, iat1);
119+
const hamilt::AtomPair<double>* ap_nspin1 = hRGint_tmp[first[is]].find_pair(iat1, iat2);
120+
const hamilt::AtomPair<double>* ap_nspin2 = hRGint_tmp[second[is]].find_pair(iat1, iat2);
120121
for (int ir = 0; ir < upper_ap->get_R_size(); ir++)
121122
{
122123
const auto R_index = upper_ap->get_R_index(ir);
123124
auto upper_mat = upper_ap->find_matrix(R_index);
124-
auto mat_nspin = ap_nspin->find_matrix(R_index);
125-
125+
auto mat_nspin1 = ap_nspin1->find_matrix(R_index);
126+
auto mat_nspin2 = ap_nspin2->find_matrix(R_index);
126127
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
127-
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
128+
for (int irow = 0; irow < mat_nspin1->get_row_size(); ++irow)
128129
{
129-
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
130+
for (int icol = 0; icol < mat_nspin1->get_col_size(); ++icol)
130131
{
131-
switch (is)
132-
{
133-
case 0:
134-
upper_mat->get_value(2*irow, 2*icol) = mat_nspin->get_value(irow, icol);
135-
upper_mat->get_value(2*irow+1, 2*icol+1) = mat_nspin->get_value(irow, icol);
136-
break;
137-
case 3:
138-
upper_mat->get_value(2*irow, 2*icol) += mat_nspin->get_value(irow, icol);
139-
upper_mat->get_value(2*irow+1, 2*icol+1) -= mat_nspin->get_value(irow, icol);
140-
break;
141-
}
142-
}
143-
}
144-
145-
if (PARAM.globalv.domag)
146-
{
147-
const hamilt::AtomPair<double>* ap_nspin = nullptr;
148-
switch (is)
149-
{
150-
case 1:
151-
ap_nspin = hR_tmp->find_pair(iat1, iat2);
152-
break;
153-
case 2:
154-
ap_nspin = hR_tmp->find_pair(iat1, iat2);
155-
break;
156-
}
157-
const auto mat_nspin = ap_nspin->find_matrix(R_index);
158-
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
159-
{
160-
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
161-
{
162-
switch(is)
163-
{
164-
case 1:
165-
upper_mat->get_value(2*irow, 2*icol+1) = mat_nspin->get_value(irow, icol);
166-
upper_mat->get_value(2*irow+1, 2*icol) = mat_nspin->get_value(irow, icol);
167-
break;
168-
case 2:
169-
upper_mat->get_value(2*irow, 2*icol+1) += std::complex<double>(0.0, 1.0) * mat_nspin->get_value(irow, icol);
170-
upper_mat->get_value(2*irow+1, 2*icol) -= std::complex<double>(0.0, 1.0) * mat_nspin->get_value(irow, icol);
171-
break;
172-
}
173-
}
132+
upper_mat->get_value(irow, icol) = mat_nspin1->get_value(irow, icol)
133+
+ std::complex<double>(clx_i[is], clx_j[is]) * mat_nspin2->get_value(irow, icol);
174134
}
175135
}
176-
177-
// fill the lower triangle matrix
178-
if(is == 3){
136+
//fill the lower triangle matrix
137+
if (PARAM.globalv.domag){
179138
if (iat1 < iat2)
180139
{
181140
auto lower_mat = lower_ap->find_matrix(-R_index);
@@ -191,15 +150,41 @@ void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
191150
}
192151
}
193152
}
194-
153+
154+
hR_tmp->set_zero();
155+
hamilt::transferSerials2Parallels( *hRGint_tmpCd, hR_tmp);
156+
for (int iap = 0; iap < hR.size_atom_pairs(); iap++)
157+
{
158+
auto* ap = &hR.get_atom_pair(iap);
159+
const int iat1 = ap->get_atom_i();
160+
const int iat2 = ap->get_atom_j();
161+
auto* ap_nspin = hR_tmp ->find_pair(iat1, iat2);
162+
for (int ir = 0; ir < ap->get_R_size(); ir++)
163+
{
164+
const auto R_index = ap->get_R_index(ir);
165+
auto upper_mat = ap->find_matrix(R_index);
166+
auto mat_nspin = ap_nspin->find_matrix(R_index);
167+
168+
// The row size and the col size of upper_matrix is double that of matrix_nspin_0
169+
for (int irow = 0; irow < mat_nspin->get_row_size(); ++irow)
170+
{
171+
for (int icol = 0; icol < mat_nspin->get_col_size(); ++icol)
172+
{
173+
upper_mat->get_value(2*irow+row_set[is], 2*icol+col_set[is]) =
174+
mat_nspin->get_value(irow, icol);
175+
}
176+
}
177+
}
178+
}
179+
delete hRGint_tmpCd;
195180
}
196181
delete[] iat2iwt;
197-
delete pv;
198-
delete hR_tmp;
199182
#else
200183

201184
#endif
202-
ModuleBase::timer::tick("Gint", "transfer_hr_gint_to_hR_nspin4");
185+
186+
187+
ModuleBase::timer::tick("Gint", "merge_hR_n4");
203188
return;
204189
}
205190

@@ -231,6 +216,9 @@ void transfer_dm_2d_to_gint(
231216
} else // NSPIN=4 case
232217
{
233218
#ifdef __MPI
219+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
220+
const int row_set[4] = {0, 0, 1, 1};
221+
const int col_set[4] = {0, 1, 0, 1};
234222
int mg = dm[0]->get_paraV()->get_global_row_size()/2;
235223
int ng = dm[0]->get_paraV()->get_global_col_size()/2;
236224
int nb = dm[0]->get_paraV()->get_block_size()/2;
@@ -246,43 +234,20 @@ void transfer_dm_2d_to_gint(
246234
auto ijr_info = dm[0]->get_ijr_info();
247235
HContainer<T>* DM2D_tmp = new hamilt::HContainer<T>(pv, nullptr, &ijr_info);
248236
//ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
249-
for (int is = 0; is < 4; is++){
237+
for (int is = 0; is < 4; is++){
250238
for (int iap = 0; iap < dm[0]->size_atom_pairs(); ++iap) {
251239
auto& ap = dm[0]->get_atom_pair(iap);
252240
int iat1 = ap.get_atom_i();
253241
int iat2 = ap.get_atom_j();
254242
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
255243
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
256-
T* tmp_pointer = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
257-
T* data_full = ap.get_pointer(ir);
258-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
259-
switch (is) {//todo: It can be written more compactly
260-
case 0:
261-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
262-
*(tmp_pointer)++ = data_full[icol];
263-
}
264-
data_full += ap.get_col_size() * 2;
265-
break;
266-
case 1:
267-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
268-
*(tmp_pointer)++ = data_full[icol + 1];
269-
}
270-
data_full += ap.get_col_size() * 2;
271-
break;
272-
case 2:
273-
data_full += ap.get_col_size();
274-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
275-
*(tmp_pointer)++ = data_full[icol];
276-
}
277-
data_full += ap.get_col_size();
278-
break;
279-
case 3:
280-
data_full += ap.get_col_size();
281-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
282-
*(tmp_pointer)++ = data_full[icol + 1];
283-
}
284-
data_full += ap.get_col_size();
285-
break;
244+
T* matrix_out = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
245+
T* matrix_in = ap.get_pointer(ir);
246+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
247+
for (int icol = 0; icol < ap.get_col_size()/2; icol ++) {
248+
int index_i = irow* ap.get_col_size()/2 + icol;
249+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
250+
matrix_out[index_i] = matrix_in[index_j];
286251
}
287252
}
288253
}

source/source_lcao/module_gint/temp_gint/gint_common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace ModuleGint
1111
template <typename T>
1212
void transfer_hr_gint_to_hR(const HContainer<T>& hr_gint, HContainer<T>& hR);
1313
// for nspin=4 case
14-
void transfer_hr_gint_to_hR_nspin4(std::vector<HContainer<double>>& hRGint_tmp,
14+
void merge_hR_n4(std::vector<HContainer<double>>& hRGint_tmp,
1515
HContainer<std::complex<double>>& hR,
1616
const GintInfo& gint_info);
1717

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void Gint_vl_metagga_nspin4::cal_gint()
1414
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1515
init_hr_gint_();
1616
cal_hr_gint_();
17-
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
17+
merge_hR_n4(hr_gint_part_, *hR_, *gint_info_);
1818
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1919
}
2020

source/source_lcao/module_gint/temp_gint/gint_vl_metagga_nspin4_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ void Gint_vl_metagga_nspin4_gpu::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
16+
merge_hR_n4(hr_gint_part_, *hR_, *gint_info_);
1717
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1818
}
1919

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ void Gint_vl_nspin4::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
16+
merge_hR_n4(hr_gint_part_, *hR_, *gint_info_);
1717
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1818
}
1919

source/source_lcao/module_gint/temp_gint/gint_vl_nspin4_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ void Gint_vl_nspin4_gpu::cal_gint()
1313
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1414
init_hr_gint_();
1515
cal_hr_gint_();
16-
transfer_hr_gint_to_hR_nspin4(hr_gint_part_, *hR_, *gint_info_);
16+
merge_hR_n4(hr_gint_part_, *hR_, *gint_info_);
1717
ModuleBase::timer::tick("Gint", "cal_gint_vl");
1818
}
1919

0 commit comments

Comments
 (0)