@@ -33,7 +33,7 @@ Gint::~Gint() {
3333 delete this ->hRGint_tmp [is];
3434 }
3535#ifdef __MPI
36- delete this ->DMRGint_full ;
36+ delete this ->dm2d_tmp ;
3737#endif
3838}
3939
@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171 this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172172 }
173173#ifdef __MPI
174- if (this ->DMRGint_full != nullptr ) {
175- delete this ->DMRGint_full ;
174+ if (this ->dm2d_tmp != nullptr ) {
175+ delete this ->dm2d_tmp ;
176176 }
177- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
178177#endif
179178 }
180179
@@ -210,12 +209,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
210209 ModuleBase::Memory::record (" Gint::DMRGint" ,
211210 this ->DMRGint [0 ]->get_memory_size ()
212211 * this ->DMRGint .size ()*nspin);
213- #ifdef __MPI
214- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
215- this ->DMRGint_full ->allocate (nullptr , true );
216- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
217- this ->DMRGint_full ->get_memory_size ());
218- #endif
219212 }
220213}
221214
@@ -231,9 +224,7 @@ void Gint::reset_DMRGint(const int& nspin)
231224 {
232225 for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
233226#ifdef __MPI
234- delete this ->DMRGint_full ;
235- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
236- this ->DMRGint_full ->allocate (nullptr , false );
227+ delete this ->dm2d_tmp ;
237228#endif
238229 }
239230 }
@@ -262,37 +253,46 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262253 } else // NSPIN=4 case
263254 {
264255#ifdef __MPI
265- hamilt::transferParallels2Serials (*DM2D[0 ], this ->DMRGint_full );
266- #else
267- this ->DMRGint_full = DM2D[0 ];
268- #endif
269- std::vector<double *> tmp_pointer (4 , nullptr );
270- for (int iap = 0 ; iap < this ->DMRGint_full ->size_atom_pairs (); ++iap) {
271- auto & ap = this ->DMRGint_full ->get_atom_pair (iap);
272- int iat1 = ap.get_atom_i ();
273- int iat2 = ap.get_atom_j ();
274- for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
275- const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
276- for (int is = 0 ; is < 4 ; is++) {
277- tmp_pointer[is] = this ->DMRGint [is]
278- ->find_matrix (iat1, iat2, r_index)
279- ->get_pointer ();
280- }
281- double * data_full = ap.get_pointer (ir);
282- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
283- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
284- *(tmp_pointer[0 ])++ = data_full[icol];
285- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
286- }
287- data_full += ap.get_col_size ();
288- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
289- *(tmp_pointer[2 ])++ = data_full[icol];
290- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
256+ // is=0:↑↑, 1:↑↓, 2:↓↑, 3:↓↓
257+ const int row_set[4 ] = {0 , 0 , 1 , 1 };
258+ const int col_set[4 ] = {0 , 1 , 0 , 1 };
259+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
260+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
261+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
262+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
263+ std::vector<int > iat2iwt (ucell->nat );
264+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
265+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
266+ }
267+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
268+ pv->set (mg, ng, nb, blacs_ctxt);
269+ pv->set_atomic_trace (iat2iwt.data (), ucell->nat , mg);
270+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
271+ this -> dm2d_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
272+ ModuleBase::Memory::record (" Gint::dm2d_tmp" , this ->dm2d_tmp ->get_memory_size ());
273+ for (int is = 0 ; is < 4 ; is++){
274+ for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
275+ auto & ap = DM2D[0 ]->get_atom_pair (iap);
276+ int iat1 = ap.get_atom_i ();
277+ int iat2 = ap.get_atom_j ();
278+ for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
279+ const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
280+ double * matrix_out = this -> dm2d_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
281+ double * matrix_in = ap.get_pointer (ir);
282+ for (int irow = 0 ; irow < ap.get_row_size ()/2 ; irow ++) {
283+ for (int icol = 0 ; icol < ap.get_col_size ()/2 ; icol++){
284+ int index_i = irow* ap.get_col_size ()/2 + icol;
285+ int index_j = (irow*2 +row_set[is]) * ap.get_col_size () + icol*2 +col_set[is];
286+ matrix_out[index_i] = matrix_in[index_j];
287+ }
291288 }
292- data_full += ap.get_col_size ();
293289 }
294290 }
291+ hamilt::transferParallels2Serials ( *(this ->dm2d_tmp ), this ->DMRGint [is]);
295292 }
293+ #else
294+ // this->DMRGint_full = DM2D[0];
295+ #endif
296296 }
297297 ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
298298}
0 commit comments