@@ -33,7 +33,7 @@ Gint::~Gint() {
33
33
delete this ->hRGint_tmp [is];
34
34
}
35
35
#ifdef __MPI
36
- delete this ->DMRGint_full ;
36
+ delete this ->DM2D_tmp ;
37
37
#endif
38
38
}
39
39
@@ -171,10 +171,9 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171
171
this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172
172
}
173
173
#ifdef __MPI
174
- if (this ->DMRGint_full != nullptr ) {
175
- delete this ->DMRGint_full ;
174
+ if (this ->DM2D_tmp != nullptr ) {
175
+ delete this ->DM2D_tmp ;
176
176
}
177
- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
178
177
#endif
179
178
}
180
179
@@ -211,12 +210,6 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
211
210
this ->hRGint_tmp [0 ]->get_memory_size ()*nspin);
212
211
ModuleBase::Memory::record (" Gint::DMRGint" ,
213
212
this ->DMRGint [0 ]->get_memory_size ()*nspin);
214
- #ifdef __MPI
215
- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
216
- this ->DMRGint_full ->allocate (nullptr , true );
217
- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
218
- this ->DMRGint_full ->get_memory_size ());
219
- #endif
220
213
}
221
214
}
222
215
@@ -232,10 +225,8 @@ void Gint::reset_DMRGint(const int& nspin)
232
225
{
233
226
for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
234
227
#ifdef __MPI
235
- delete this ->DMRGint_full ;
236
- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
237
- this ->DMRGint_full ->allocate (nullptr , false );
238
- #endif
228
+ delete this ->DM2D_tmp ;
229
+ #endif
239
230
}
240
231
}
241
232
}
@@ -263,37 +254,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
263
254
} else // NSPIN=4 case
264
255
{
265
256
#ifdef __MPI
266
- hamilt::transferParallels2Serials (*DM2D[0 ], this ->DMRGint_full );
267
- #else
268
- this ->DMRGint_full = DM2D[0 ];
269
- #endif
270
- std::vector<double *> tmp_pointer (4 , nullptr );
271
- for (int iap = 0 ; iap < this ->DMRGint_full ->size_atom_pairs (); ++iap) {
272
- auto & ap = this ->DMRGint_full ->get_atom_pair (iap);
273
- int iat1 = ap.get_atom_i ();
274
- int iat2 = ap.get_atom_j ();
275
- for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
276
- const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
277
- for (int is = 0 ; is < 4 ; is++) {
278
- tmp_pointer[is] = this ->DMRGint [is]
279
- ->find_matrix (iat1, iat2, r_index)
280
- ->get_pointer ();
281
- }
282
- double * data_full = ap.get_pointer (ir);
283
- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
284
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
285
- *(tmp_pointer[0 ])++ = data_full[icol];
286
- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
287
- }
288
- data_full += ap.get_col_size ();
289
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290
- *(tmp_pointer[2 ])++ = data_full[icol];
291
- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
257
+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
258
+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
259
+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ()/2 ;
260
+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
261
+ int *iat2iwt = new int [ucell->nat ];
262
+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
263
+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
264
+ }
265
+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
266
+ pv->set (mg, ng, nb, blacs_ctxt);
267
+ pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
268
+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
269
+ this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
270
+ ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
271
+ for (int is = 0 ; is < 4 ; is++){
272
+ for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
273
+ auto & ap = DM2D[0 ]->get_atom_pair (iap);
274
+ int iat1 = ap.get_atom_i ();
275
+ int iat2 = ap.get_atom_j ();
276
+ for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
277
+ const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
278
+ double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
279
+ double * data_full = ap.get_pointer (ir);
280
+ for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
281
+ switch (is) {// todo: It can be written more compactly
282
+ case 0 :
283
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
284
+ *(tmp_pointer)++ = data_full[icol];
285
+ }
286
+ data_full += ap.get_col_size () * 2 ;
287
+ break ;
288
+ case 1 :
289
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290
+ *(tmp_pointer)++ = data_full[icol + 1 ];
291
+ }
292
+ data_full += ap.get_col_size () * 2 ;
293
+ break ;
294
+ case 2 :
295
+ data_full += ap.get_col_size ();
296
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297
+ *(tmp_pointer)++ = data_full[icol];
298
+ }
299
+ data_full += ap.get_col_size ();
300
+ break ;
301
+ case 3 :
302
+ data_full += ap.get_col_size ();
303
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
304
+ *(tmp_pointer)++ = data_full[icol + 1 ];
305
+ }
306
+ data_full += ap.get_col_size ();
307
+ break ;
308
+ }
292
309
}
293
- data_full += ap.get_col_size ();
294
310
}
295
311
}
312
+ hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
296
313
}
314
+ // delete iat2iwt;
315
+ // delete pv;
316
+ // delete this-> DM2D_tmp;
317
+ #else
318
+ // this->DMRGint_full = DM2D[0];
319
+ #endif
297
320
}
298
321
ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
299
322
}
0 commit comments