Skip to content

Commit e6a1983

Browse files
committed
simplify the compute code
1 parent 7d4fe5a commit e6a1983

File tree

2 files changed

+20
-60
lines changed

2 files changed

+20
-60
lines changed

source/source_lcao/module_gint/gint_old.cpp

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
253253
} else // NSPIN=4 case
254254
{
255255
#ifdef __MPI
256+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257+
const int row_set[4] = {0, 0, 1, 1};
258+
const int col_set[4] = {0, 1, 0, 1};
256259
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
257260
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
258261
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
@@ -274,36 +277,13 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
274277
int iat2 = ap.get_atom_j();
275278
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
276279
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
277-
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
278-
double* data_full = ap.get_pointer(ir);
279-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
280-
switch (is) {//todo: It can be written more compactly
281-
case 0:
282-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
283-
*(tmp_pointer)++ = data_full[icol];
284-
}
285-
data_full += ap.get_col_size() * 2;
286-
break;
287-
case 1:
288-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
289-
*(tmp_pointer)++ = data_full[icol + 1];
290-
}
291-
data_full += ap.get_col_size() * 2;
292-
break;
293-
case 2:
294-
data_full += ap.get_col_size();
295-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
296-
*(tmp_pointer)++ = data_full[icol];
297-
}
298-
data_full += ap.get_col_size();
299-
break;
300-
case 3:
301-
data_full += ap.get_col_size();
302-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
303-
*(tmp_pointer)++ = data_full[icol + 1];
304-
}
305-
data_full += ap.get_col_size();
306-
break;
280+
double* matrix_out = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
281+
double* matrix_in = ap.get_pointer(ir);
282+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
283+
for (int icol = 0; icol < ap.get_col_size()/2; icol++){
284+
int index_i = irow* ap.get_col_size()/2 + icol;
285+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
286+
matrix_out[index_i] = matrix_in[index_j];
307287
}
308288
}
309289
}

source/source_lcao/module_gint/temp_gint/gint_common.cpp

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ void transfer_dm_2d_to_gint(
163163
} else // NSPIN=4 case
164164
{
165165
#ifdef __MPI
166+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
167+
const int row_set[4] = {0, 0, 1, 1};
168+
const int col_set[4] = {0, 1, 0, 1};
166169
int mg = dm[0]->get_paraV()->get_global_row_size()/2;
167170
int ng = dm[0]->get_paraV()->get_global_col_size()/2;
168171
int nb = dm[0]->get_paraV()->get_block_size()/2;
@@ -185,36 +188,13 @@ void transfer_dm_2d_to_gint(
185188
int iat2 = ap.get_atom_j();
186189
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
187190
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
188-
T* tmp_pointer = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
189-
T* data_full = ap.get_pointer(ir);
190-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
191-
switch (is) {//todo: It can be written more compactly
192-
case 0:
193-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
194-
*(tmp_pointer)++ = data_full[icol];
195-
}
196-
data_full += ap.get_col_size() * 2;
197-
break;
198-
case 1:
199-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
200-
*(tmp_pointer)++ = data_full[icol + 1];
201-
}
202-
data_full += ap.get_col_size() * 2;
203-
break;
204-
case 2:
205-
data_full += ap.get_col_size();
206-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
207-
*(tmp_pointer)++ = data_full[icol];
208-
}
209-
data_full += ap.get_col_size();
210-
break;
211-
case 3:
212-
data_full += ap.get_col_size();
213-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
214-
*(tmp_pointer)++ = data_full[icol + 1];
215-
}
216-
data_full += ap.get_col_size();
217-
break;
191+
T* matrix_out = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
192+
T* matrix_in = ap.get_pointer(ir);
193+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
194+
for (int icol = 0; icol < ap.get_col_size()/2; icol ++) {
195+
int index_i = irow* ap.get_col_size()/2 + icol;
196+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
197+
matrix_out[index_i] = matrix_in[index_j];
218198
}
219199
}
220200
}

0 commit comments

Comments
 (0)