@@ -33,7 +33,7 @@ Gint::~Gint() {
3333 delete this ->hRGint_tmp [is];
3434 }
3535#ifdef __MPI
36- delete this ->DMRGint_full ;
36+ delete this ->DM2D_tmp ;
3737#endif
3838}
3939
@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171171 this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172172 }
173173#ifdef __MPI
174- if (this ->DMRGint_full != nullptr ) {
175- delete this ->DMRGint_full ;
174+ if (this ->DM2D_tmp != nullptr ) {
175+ delete this ->DM2D_tmp ;
176176 }
177- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
178177#endif
179178 }
180179
@@ -199,6 +198,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
199198 } else {
200199 this ->hRGintCd ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
201200 this ->hRGintCd ->allocate (nullptr , true );
201+ ModuleBase::Memory::record (" Gint::hRGintCd" ,
202+ this ->hRGintCd ->get_memory_size ());
202203 for (int is = 0 ; is < nspin; is++) {
203204 this ->hRGint_tmp [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
204205 this ->DMRGint [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
@@ -208,14 +209,7 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
208209 ModuleBase::Memory::record (" Gint::hRGint_tmp" ,
209210 this ->hRGint_tmp [0 ]->get_memory_size ()*nspin);
210211 ModuleBase::Memory::record (" Gint::DMRGint" ,
211- this ->DMRGint [0 ]->get_memory_size ()
212- * this ->DMRGint .size ()*nspin);
213- #ifdef __MPI
214- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
215- this ->DMRGint_full ->allocate (nullptr , true );
216- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
217- this ->DMRGint_full ->get_memory_size ());
218- #endif
212+ this ->DMRGint [0 ]->get_memory_size ()*nspin);
219213 }
220214}
221215
@@ -231,10 +225,8 @@ void Gint::reset_DMRGint(const int& nspin)
231225 {
232226 for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
233227#ifdef __MPI
234- delete this ->DMRGint_full ;
235- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
236- this ->DMRGint_full ->allocate (nullptr , false );
237- #endif
228+ delete this ->DM2D_tmp ;
229+ #endif
238230 }
239231 }
240232}
@@ -262,37 +254,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262254 } else // NSPIN=4 case
263255 {
264256#ifdef __MPI
265- hamilt::transferParallels2Serials (*DM2D[0 ], this ->DMRGint_full );
266- #else
267- this ->DMRGint_full = DM2D[0 ];
268- #endif
269- std::vector<double *> tmp_pointer (4 , nullptr );
270- for (int iap = 0 ; iap < this ->DMRGint_full ->size_atom_pairs (); ++iap) {
271- auto & ap = this ->DMRGint_full ->get_atom_pair (iap);
272- int iat1 = ap.get_atom_i ();
273- int iat2 = ap.get_atom_j ();
274- for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
275- const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
276- for (int is = 0 ; is < 4 ; is++) {
277- tmp_pointer[is] = this ->DMRGint [is]
278- ->find_matrix (iat1, iat2, r_index)
279- ->get_pointer ();
280- }
281- double * data_full = ap.get_pointer (ir);
282- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
283- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
284- *(tmp_pointer[0 ])++ = data_full[icol];
285- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
286- }
287- data_full += ap.get_col_size ();
288- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
289- *(tmp_pointer[2 ])++ = data_full[icol];
290- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
257+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
258+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
259+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
260+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
261+ int *iat2iwt = new int [ucell->nat ];
262+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
263+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
264+ }
265+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
266+ pv->set (mg, ng, nb, blacs_ctxt);
267+ pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
268+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
269+ this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
270+ ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
271+ for (int is = 0 ; is < 4 ; is++){
272+ for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
273+ auto & ap = DM2D[0 ]->get_atom_pair (iap);
274+ int iat1 = ap.get_atom_i ();
275+ int iat2 = ap.get_atom_j ();
276+ for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
277+ const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
278+ double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
279+ double * data_full = ap.get_pointer (ir);
280+ for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
281+ switch (is) {// todo: It can be written more compactly
282+ case 0 :
283+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
284+ *(tmp_pointer)++ = data_full[icol];
285+ }
286+ data_full += ap.get_col_size () * 2 ;
287+ break ;
288+ case 1 :
289+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290+ *(tmp_pointer)++ = data_full[icol + 1 ];
291+ }
292+ data_full += ap.get_col_size () * 2 ;
293+ break ;
294+ case 2 :
295+ data_full += ap.get_col_size ();
296+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297+ *(tmp_pointer)++ = data_full[icol];
298+ }
299+ data_full += ap.get_col_size ();
300+ break ;
301+ case 3 :
302+ data_full += ap.get_col_size ();
303+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
304+ *(tmp_pointer)++ = data_full[icol + 1 ];
305+ }
306+ data_full += ap.get_col_size ();
307+ break ;
308+ }
291309 }
292- data_full += ap.get_col_size ();
293310 }
294311 }
312+ hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
295313 }
314+ // delete iat2iwt;
315+ // delete pv;
316+ // delete this-> DM2D_tmp;
317+ #else
318+ // this->DMRGint_full = DM2D[0];
319+ #endif
296320 }
297321 ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
298322}
0 commit comments