@@ -253,6 +253,9 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
253253 } else // NSPIN=4 case
254254 {
255255#ifdef __MPI
256+ // is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257+ const int row_set[4 ] = {0 , 0 , 1 , 1 };
258+ const int col_set[4 ] = {0 , 1 , 0 , 1 };
256259 int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
257260 int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
258261 int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
@@ -274,36 +277,13 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
274277 int iat2 = ap.get_atom_j ();
275278 for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
276279 const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
277- double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
278- double * data_full = ap.get_pointer (ir);
279- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
280- switch (is) {// todo: It can be written more compactly
281- case 0 :
282- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
283- *(tmp_pointer)++ = data_full[icol];
284- }
285- data_full += ap.get_col_size () * 2 ;
286- break ;
287- case 1 :
288- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
289- *(tmp_pointer)++ = data_full[icol + 1 ];
290- }
291- data_full += ap.get_col_size () * 2 ;
292- break ;
293- case 2 :
294- data_full += ap.get_col_size ();
295- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
296- *(tmp_pointer)++ = data_full[icol];
297- }
298- data_full += ap.get_col_size ();
299- break ;
300- case 3 :
301- data_full += ap.get_col_size ();
302- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
303- *(tmp_pointer)++ = data_full[icol + 1 ];
304- }
305- data_full += ap.get_col_size ();
306- break ;
280+ double * matrix_out = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
281+ double * matrix_in = ap.get_pointer (ir);
282+ for (int irow = 0 ; irow < ap.get_row_size ()/2 ; irow ++) {
283+ for (int icol = 0 ; icol < ap.get_col_size ()/2 ; icol++){
284+ int index_i = irow* ap.get_col_size ()/2 + icol;
285+ int index_j = (irow*2 +row_set[is]) * ap.get_col_size () + icol*2 +col_set[is];
286+ matrix_out[index_i] = matrix_in[index_j];
307287 }
308288 }
309289 }
0 commit comments