@@ -33,7 +33,7 @@ Gint::~Gint() {
3333 delete this ->hRGint_tmp [is];
3434 }
3535#ifdef __MPI
36- delete this ->DMRGint_full ;
36+ delete this ->DM2D_tmp ;
3737#endif
3838}
3939
@@ -170,16 +170,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
170170 this ->DMRGint [is] = new hamilt::HContainer<double >(ucell_in.nat );
171171 this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172172 }
173- #ifdef __MPI
174- if (this ->DM2D_tmp != nullptr ) {
175- delete this ->DM2D_tmp ;
176- }
177- this ->DM2D_tmp = new hamilt::HContainer<double >(ucell_in.nat );
178- if (this ->DMRGint_full != nullptr ) {
179- delete this ->DMRGint_full ;
180- }
181- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
182- #endif
173+
174+
183175 }
184176
185177 if (PARAM.globalv .gamma_only_local && nspin != 4 ) {
@@ -205,30 +197,19 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
205197 this ->hRGintCd ->allocate (nullptr , true );
206198 ModuleBase::Memory::record (" Gint::hRGintCd" ,
207199 this ->hRGintCd ->get_memory_size ());
208- this ->DM2D_tmp ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
209- this ->DM2D_tmp ->allocate (nullptr , true );
210- ModuleBase::Memory::record (" Gint::DM2D_tmp" ,
211- this ->DM2D_tmp ->get_memory_size ());
212200 for (int is = 0 ; is < nspin; is++) {
213201 this ->hRGint_tmp [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
214202 this ->DMRGint [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
215-
216203 this ->hRGint_tmp [is]->allocate (nullptr , true );
217204 this ->DMRGint [is]->allocate (nullptr , true );
218205 }
219206 ModuleBase::Memory::record (" Gint::hRGint_tmp" ,
220207 this ->hRGint_tmp [0 ]->get_memory_size ()*nspin);
221208 ModuleBase::Memory::record (" Gint::DMRGint" ,
222209 this ->DMRGint [0 ]->get_memory_size ()*nspin);
223- #ifdef __MPI
224- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
225- this ->DMRGint_full ->allocate (nullptr , true );
226- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
227- this ->DMRGint_full ->get_memory_size ());
228- #endif
210+ // GlobalV::ofs_running << "Gint::DMRGint: " << float(this->DMRGint[0]->get_memory_size())/1024/1024 *nspin<< std::endl;
211+ // GlobalV::ofs_running << "Gint::hRGint_tmp: " << float(this->hRGint_tmp[0]->get_memory_size())/1024/1024 *nspin<< std::endl;
229212 }
230- // std::cout<<" DMRGint " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
231- // std::cout<<" DMRGint_full " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
232213}
233214
234215void Gint::reset_DMRGint (const int & nspin)
@@ -242,11 +223,6 @@ void Gint::reset_DMRGint(const int& nspin)
242223 if (nspin == 4 )
243224 {
244225 for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
245- #ifdef __MPI
246- delete this ->DMRGint_full ;
247- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
248- this ->DMRGint_full ->allocate (nullptr , false );
249- #endif
250226 }
251227 }
252228}
@@ -273,68 +249,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
273249 }
274250 } else // NSPIN=4 case
275251 {
276-
277- for (int is = 0 ; is < 4 ; is++) {
278252#ifdef __MPI
279- std::vector<double *> tmp_pointer (4 , nullptr );
253+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
254+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
255+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ();
256+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
257+ int *iat2iwt = new int [ucell->nat ];
258+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
259+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
260+ }
261+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
262+ pv->set (mg, ng, nb, blacs_ctxt);
263+ pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
264+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
265+ this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
266+ ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
267+ // GlobalV::ofs_running << "Gint::DM2D_tmp: " << float(this -> DM2D_tmp->get_memory_size())/1024/1024 << std::endl;
268+ for (int is = 0 ; is < 4 ; is++){
280269 for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
281270 auto & ap = DM2D[0 ]->get_atom_pair (iap);
282271 int iat1 = ap.get_atom_i ();
283272 int iat2 = ap.get_atom_j ();
284273 for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
285274 const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
286- for (int is = 0 ; is < 4 ; is++) {
287- tmp_pointer[is] = this -> DM2D_tmp->find_matrix (iat1, iat2, r_index)->get_pointer ();
288- }
275+ double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
289276 double * data_full = ap.get_pointer (ir);
290277 for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
291- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
292- *(tmp_pointer[0 ])++ = data_full[icol];
293- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
278+ switch (is) {// todo: It can be written more compactly
279+ case 0 :
280+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
281+ *(tmp_pointer)++ = data_full[icol];
282+ }
283+ data_full += ap.get_col_size () * 2 ;
284+ break ;
285+ case 1 :
286+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
287+ *(tmp_pointer)++ = data_full[icol + 1 ];
288+ }
289+ data_full += ap.get_col_size () * 2 ;
290+ break ;
291+ case 2 :
292+ data_full += ap.get_col_size ();
293+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
294+ *(tmp_pointer)++ = data_full[icol];
295+ }
296+ data_full += ap.get_col_size ();
297+ break ;
298+ case 3 :
299+ data_full += ap.get_col_size ();
300+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
301+ *(tmp_pointer)++ = data_full[icol + 1 ];
302+ }
303+ data_full += ap.get_col_size ();
304+ break ;
294305 }
295- data_full += ap.get_col_size ();
296- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297- *(tmp_pointer[2 ])++ = data_full[icol];
298- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
299- }
300- data_full += ap.get_col_size ();
301306 }
302307 }
303308 }
304309 hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
305-
306-
310+ }
307311#else
308- this ->DMRGint [is] = DM2D[0 ];
312+ // TODO: the non-MPI version of the NSPIN=4 transfer is not implemented yet; this branch currently does nothing.
313+
309314#endif
310- }
311- // std::vector<double*> tmp_pointer(4, nullptr);
312- // for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
313- // auto& ap = this->DMRGint_full->get_atom_pair(iap);
314- // int iat1 = ap.get_atom_i();
315- // int iat2 = ap.get_atom_j();
316- // for (int ir = 0; ir < ap.get_R_size(); ++ir) {
317- // const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
318- // for (int is = 0; is < 4; is++) {
319- // tmp_pointer[is] = this->DMRGint[is]
320- // ->find_matrix(iat1, iat2, r_index)
321- // ->get_pointer();
322- // }
323- // double* data_full = ap.get_pointer(ir);
324- // for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
325- // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
326- // *(tmp_pointer[0])++ = data_full[icol];
327- // *(tmp_pointer[1])++ = data_full[icol + 1];
328- // }
329- // data_full += ap.get_col_size();
330- // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
331- // *(tmp_pointer[2])++ = data_full[icol];
332- // *(tmp_pointer[3])++ = data_full[icol + 1];
333- // }
334- // data_full += ap.get_col_size();
335- // }
336- // }
337- // }
338- }
315+ }
339316 ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
340317}
0 commit comments