Skip to content

Commit 61360c7

Browse files
committed
Merge branch 'LTS-2' into LTS-3
2 parents d130033 + 3201feb commit 61360c7

File tree

1 file changed

+12
-31
lines changed
  • source/module_hamilt_lcao/module_gint

1 file changed

+12
-31
lines changed

source/module_hamilt_lcao/module_gint/gint.cpp

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
252252
} else // NSPIN=4 case
253253
{
254254
#ifdef __MPI
255+
// is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
256+
const int row_set[4] = {0, 0, 1, 1};
257+
const int col_set[4] = {0, 1, 0, 1};
255258
int mg = DM2D[0]->get_paraV()->get_global_row_size()/2;
256259
int ng = DM2D[0]->get_paraV()->get_global_col_size()/2;
257260
int nb = DM2D[0]->get_paraV()->get_block_size()/2;
@@ -265,6 +268,7 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
265268
pv->set_atomic_trace(iat2iwt, ucell->nat, mg);
266269
auto ijr_info = DM2D[0]->get_ijr_info();
267270
this-> DM2D_tmp = new hamilt::HContainer<double>(pv, nullptr, &ijr_info);
271+
this-> DM2D_tmp->set_zero();
268272
ModuleBase::Memory::record("Gint::DM2D_tmp", this->DM2D_tmp->get_memory_size());
269273
for (int is = 0; is < 4; is++){
270274
for (int iap = 0; iap < DM2D[0]->size_atom_pairs(); ++iap) {
@@ -273,43 +277,20 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
273277
int iat2 = ap.get_atom_j();
274278
for (int ir = 0; ir < ap.get_R_size(); ++ir) {
275279
const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
276-
double* tmp_pointer = this -> DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
277-
double* data_full = ap.get_pointer(ir);
278-
for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
279-
switch (is) {//todo: It can be written more compactly
280-
case 0:
281-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
282-
*(tmp_pointer)++ = data_full[icol];
283-
}
284-
data_full += ap.get_col_size() * 2;
285-
break;
286-
case 1:
287-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
288-
*(tmp_pointer)++ = data_full[icol + 1];
289-
}
290-
data_full += ap.get_col_size() * 2;
291-
break;
292-
case 2:
293-
data_full += ap.get_col_size();
294-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
295-
*(tmp_pointer)++ = data_full[icol];
296-
}
297-
data_full += ap.get_col_size();
298-
break;
299-
case 3:
300-
data_full += ap.get_col_size();
301-
for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
302-
*(tmp_pointer)++ = data_full[icol + 1];
303-
}
304-
data_full += ap.get_col_size();
305-
break;
280+
double* matrix_out = DM2D_tmp -> find_matrix(iat1, iat2, r_index)->get_pointer();
281+
double* matrix_in = ap.get_pointer(ir);
282+
for (int irow = 0; irow < ap.get_row_size()/2; irow ++) {
283+
for (int icol = 0; icol < ap.get_col_size()/2; icol ++) {
284+
int index_i = irow* ap.get_col_size()/2 + icol;
285+
int index_j = (irow*2+row_set[is]) * ap.get_col_size() + icol*2+col_set[is];
286+
matrix_out[index_i] = matrix_in[index_j];
306287
}
307288
}
308289
}
309290
}
310291
hamilt::transferParallels2Serials( *(this->DM2D_tmp), this->DMRGint[is]);
311292
}
312-
// delete iat2iwt;
293+
// delete iat2iwt [];
313294
// delete pv;
314295
// delete this-> DM2D_tmp;
315296
#else

0 commit comments

Comments
 (0)