@@ -252,6 +252,9 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
252252 } else // NSPIN=4 case
253253 {
254254#ifdef __MPI
255+ // is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
256+ const int row_set[4 ] = {0 , 0 , 1 , 1 };
257+ const int col_set[4 ] = {0 , 1 , 0 , 1 };
255258 int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
256259 int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
257260 int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
@@ -265,6 +268,7 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
265268 pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
266269 auto ijr_info = DM2D[0 ]->get_ijr_info ();
267270 this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
271+ this -> DM2D_tmp->set_zero ();
268272 ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
269273 for (int is = 0 ; is < 4 ; is++){
270274 for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
@@ -273,43 +277,20 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
273277 int iat2 = ap.get_atom_j ();
274278 for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
275279 const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
276- double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
277- double * data_full = ap.get_pointer (ir);
278- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
279- switch (is) {// todo: It can be written more compactly
280- case 0 :
281- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
282- *(tmp_pointer)++ = data_full[icol];
283- }
284- data_full += ap.get_col_size () * 2 ;
285- break ;
286- case 1 :
287- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
288- *(tmp_pointer)++ = data_full[icol + 1 ];
289- }
290- data_full += ap.get_col_size () * 2 ;
291- break ;
292- case 2 :
293- data_full += ap.get_col_size ();
294- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
295- *(tmp_pointer)++ = data_full[icol];
296- }
297- data_full += ap.get_col_size ();
298- break ;
299- case 3 :
300- data_full += ap.get_col_size ();
301- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
302- *(tmp_pointer)++ = data_full[icol + 1 ];
303- }
304- data_full += ap.get_col_size ();
305- break ;
280+ double * matrix_out = DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
281+ double * matrix_in = ap.get_pointer (ir);
282+ for (int irow = 0 ; irow < ap.get_row_size ()/2 ; irow ++) {
283+ for (int icol = 0 ; icol < ap.get_col_size ()/2 ; icol ++) {
284+ int index_i = irow* ap.get_col_size ()/2 + icol;
285+ int index_j = (irow*2 +row_set[is]) * ap.get_col_size () + icol*2 +col_set[is];
286+ matrix_out[index_i] = matrix_in[index_j];
306287 }
307288 }
308289 }
309290 }
310291 hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
311292 }
312- // delete iat2iwt;
293+ // delete iat2iwt [] ;
313294 // delete pv;
314295 // delete this-> DM2D_tmp;
315296#else
0 commit comments