@@ -33,7 +33,7 @@ Gint::~Gint() {
3333 delete this ->hRGint_tmp [is];
3434 }
3535#ifdef __MPI
36- delete this ->DMRGint_full ;
36+ delete this ->DM2D_tmp ;
3737#endif
3838}
3939
@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171 this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172172 }
173173#ifdef __MPI
174- if (this ->DMRGint_full != nullptr ) {
175- delete this ->DMRGint_full ;
174+ if (this ->DM2D_tmp != nullptr ) {
175+ delete this ->DM2D_tmp ;
176176 }
177- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
178177#endif
179178 }
180179
@@ -211,12 +210,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
211210 this ->hRGint_tmp [0 ]->get_memory_size ()*nspin);
212211 ModuleBase::Memory::record (" Gint::DMRGint" ,
213212 this ->DMRGint [0 ]->get_memory_size ()*nspin);
214- #ifdef __MPI
215- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
216- this ->DMRGint_full ->allocate (nullptr , true );
217- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
218- this ->DMRGint_full ->get_memory_size ());
219- #endif
220213 }
221214}
222215
@@ -232,10 +225,8 @@ void Gint::reset_DMRGint(const int& nspin)
232225 {
233226 for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
234227#ifdef __MPI
235- delete this ->DMRGint_full ;
236- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
237- this ->DMRGint_full ->allocate (nullptr , false );
238- #endif
228+ delete this ->DM2D_tmp ;
229+ #endif
239230 }
240231 }
241232}
@@ -263,37 +254,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
263254 } else // NSPIN=4 case
264255 {
265256#ifdef __MPI
266- hamilt::transferParallels2Serials (*DM2D[0 ], this ->DMRGint_full );
267- #else
268- this ->DMRGint_full = DM2D[0 ];
269- #endif
270- std::vector<double *> tmp_pointer (4 , nullptr );
271- for (int iap = 0 ; iap < this ->DMRGint_full ->size_atom_pairs (); ++iap) {
272- auto & ap = this ->DMRGint_full ->get_atom_pair (iap);
273- int iat1 = ap.get_atom_i ();
274- int iat2 = ap.get_atom_j ();
275- for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
276- const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
277- for (int is = 0 ; is < 4 ; is++) {
278- tmp_pointer[is] = this ->DMRGint [is]
279- ->find_matrix (iat1, iat2, r_index)
280- ->get_pointer ();
281- }
282- double * data_full = ap.get_pointer (ir);
283- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
284- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
285- *(tmp_pointer[0 ])++ = data_full[icol];
286- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
287- }
288- data_full += ap.get_col_size ();
289- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290- *(tmp_pointer[2 ])++ = data_full[icol];
291- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
257+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
258+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
259+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
260+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
261+ int *iat2iwt = new int [ucell->nat ];
262+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
263+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
264+ }
265+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
266+ pv->set (mg, ng, nb, blacs_ctxt);
267+ pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
268+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
269+ this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
270+ ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
271+ for (int is = 0 ; is < 4 ; is++){
272+ for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
273+ auto & ap = DM2D[0 ]->get_atom_pair (iap);
274+ int iat1 = ap.get_atom_i ();
275+ int iat2 = ap.get_atom_j ();
276+ for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
277+ const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
278+ double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
279+ double * data_full = ap.get_pointer (ir);
280+ for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
281+ switch (is) {// todo: It can be written more compactly
282+ case 0 :
283+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
284+ *(tmp_pointer)++ = data_full[icol];
285+ }
286+ data_full += ap.get_col_size () * 2 ;
287+ break ;
288+ case 1 :
289+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290+ *(tmp_pointer)++ = data_full[icol + 1 ];
291+ }
292+ data_full += ap.get_col_size () * 2 ;
293+ break ;
294+ case 2 :
295+ data_full += ap.get_col_size ();
296+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297+ *(tmp_pointer)++ = data_full[icol];
298+ }
299+ data_full += ap.get_col_size ();
300+ break ;
301+ case 3 :
302+ data_full += ap.get_col_size ();
303+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
304+ *(tmp_pointer)++ = data_full[icol + 1 ];
305+ }
306+ data_full += ap.get_col_size ();
307+ break ;
308+ }
292309 }
293- data_full += ap.get_col_size ();
294310 }
295311 }
312+ hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
296313 }
314+ // delete iat2iwt;
315+ // delete pv;
316+ // delete this-> DM2D_tmp;
317+ #else
318+ // this->DMRGint_full = DM2D[0];
319+ #endif
297320 }
298321 ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
299322}
0 commit comments