@@ -33,7 +33,7 @@ Gint::~Gint() {
33
33
delete this ->hRGint_tmp [is];
34
34
}
35
35
#ifdef __MPI
36
- delete this ->DMRGint_full ;
36
+ delete this ->DM2D_tmp ;
37
37
#endif
38
38
}
39
39
@@ -170,16 +170,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
170
170
this ->DMRGint [is] = new hamilt::HContainer<double >(ucell_in.nat );
171
171
this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172
172
}
173
- #ifdef __MPI
174
- if (this ->DM2D_tmp != nullptr ) {
175
- delete this ->DM2D_tmp ;
176
- }
177
- this ->DM2D_tmp = new hamilt::HContainer<double >(ucell_in.nat );
178
- if (this ->DMRGint_full != nullptr ) {
179
- delete this ->DMRGint_full ;
180
- }
181
- this ->DMRGint_full = new hamilt::HContainer<double >(ucell_in.nat );
182
- #endif
173
+
174
+
183
175
}
184
176
185
177
if (PARAM.globalv .gamma_only_local && nspin != 4 ) {
@@ -205,30 +197,19 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
205
197
this ->hRGintCd ->allocate (nullptr , true );
206
198
ModuleBase::Memory::record (" Gint::hRGintCd" ,
207
199
this ->hRGintCd ->get_memory_size ());
208
- this ->DM2D_tmp ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
209
- this ->DM2D_tmp ->allocate (nullptr , true );
210
- ModuleBase::Memory::record (" Gint::DM2D_tmp" ,
211
- this ->DM2D_tmp ->get_memory_size ());
212
200
for (int is = 0 ; is < nspin; is++) {
213
201
this ->hRGint_tmp [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
214
202
this ->DMRGint [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
215
-
216
203
this ->hRGint_tmp [is]->allocate (nullptr , true );
217
204
this ->DMRGint [is]->allocate (nullptr , true );
218
205
}
219
206
ModuleBase::Memory::record (" Gint::hRGint_tmp" ,
220
207
this ->hRGint_tmp [0 ]->get_memory_size ()*nspin);
221
208
ModuleBase::Memory::record (" Gint::DMRGint" ,
222
209
this ->DMRGint [0 ]->get_memory_size ()*nspin);
223
- #ifdef __MPI
224
- this ->DMRGint_full ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
225
- this ->DMRGint_full ->allocate (nullptr , true );
226
- ModuleBase::Memory::record (" Gint::DMRGint_full" ,
227
- this ->DMRGint_full ->get_memory_size ());
228
- #endif
210
+ // GlobalV::ofs_running << "Gint::DMRGint: " << float(this->DMRGint[0]->get_memory_size())/1024/1024 *nspin<< std::endl;
211
+ // GlobalV::ofs_running << "Gint::hRGint_tmp: " << float(this->hRGint_tmp[0]->get_memory_size())/1024/1024 *nspin<< std::endl;
229
212
}
230
- // std::cout<<" DMRGint " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
231
- // std::cout<<" DMRGint_full " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
232
213
}
233
214
234
215
void Gint::reset_DMRGint (const int & nspin)
@@ -242,11 +223,6 @@ void Gint::reset_DMRGint(const int& nspin)
242
223
if (nspin == 4 )
243
224
{
244
225
for (auto & d : this ->DMRGint ) { d->allocate (nullptr , false ); }
245
- #ifdef __MPI
246
- delete this ->DMRGint_full ;
247
- this ->DMRGint_full = new hamilt::HContainer<double >(*this ->hRGint );
248
- this ->DMRGint_full ->allocate (nullptr , false );
249
- #endif
250
226
}
251
227
}
252
228
}
@@ -273,68 +249,69 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
273
249
}
274
250
} else // NSPIN=4 case
275
251
{
276
-
277
- for (int is = 0 ; is < 4 ; is++) {
278
252
#ifdef __MPI
279
- std::vector<double *> tmp_pointer (4 , nullptr );
253
+ int mg = DM2D[0 ]->get_paraV ()->get_global_row_size ()/2 ;
254
+ int ng = DM2D[0 ]->get_paraV ()->get_global_col_size ()/2 ;
255
+ int nb = DM2D[0 ]->get_paraV ()->get_block_size ();
256
+ int blacs_ctxt = DM2D[0 ]->get_paraV ()->blacs_ctxt ;
257
+ int *iat2iwt = new int [ucell->nat ];
258
+ for (int iat = 0 ; iat < ucell->nat ; iat++) {
259
+ iat2iwt[iat] = ucell->get_iat2iwt ()[iat]/2 ;
260
+ }
261
+ Parallel_Orbitals *pv = new Parallel_Orbitals ();
262
+ pv->set (mg, ng, nb, blacs_ctxt);
263
+ pv->set_atomic_trace (iat2iwt, ucell->nat , mg);
264
+ auto ijr_info = DM2D[0 ]->get_ijr_info ();
265
+ this -> DM2D_tmp = new hamilt::HContainer<double >(pv, nullptr , &ijr_info);
266
+ ModuleBase::Memory::record (" Gint::DM2D_tmp" , this ->DM2D_tmp ->get_memory_size ());
267
+ // GlobalV::ofs_running << "Gint::DM2D_tmp: " << float(this -> DM2D_tmp->get_memory_size())/1024/1024 << std::endl;
268
+ for (int is = 0 ; is < 4 ; is++){
280
269
for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
281
270
auto & ap = DM2D[0 ]->get_atom_pair (iap);
282
271
int iat1 = ap.get_atom_i ();
283
272
int iat2 = ap.get_atom_j ();
284
273
for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
285
274
const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
286
- for (int is = 0 ; is < 4 ; is++) {
287
- tmp_pointer[is] = this -> DM2D_tmp->find_matrix (iat1, iat2, r_index)->get_pointer ();
288
- }
275
+ double * tmp_pointer = this -> DM2D_tmp -> find_matrix (iat1, iat2, r_index)->get_pointer ();
289
276
double * data_full = ap.get_pointer (ir);
290
277
for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
291
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
292
- *(tmp_pointer[0 ])++ = data_full[icol];
293
- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
278
+ switch (is) {// todo: It can be written more compactly
279
+ case 0 :
280
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
281
+ *(tmp_pointer)++ = data_full[icol];
282
+ }
283
+ data_full += ap.get_col_size () * 2 ;
284
+ break ;
285
+ case 1 :
286
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
287
+ *(tmp_pointer)++ = data_full[icol + 1 ];
288
+ }
289
+ data_full += ap.get_col_size () * 2 ;
290
+ break ;
291
+ case 2 :
292
+ data_full += ap.get_col_size ();
293
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
294
+ *(tmp_pointer)++ = data_full[icol];
295
+ }
296
+ data_full += ap.get_col_size ();
297
+ break ;
298
+ case 3 :
299
+ data_full += ap.get_col_size ();
300
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
301
+ *(tmp_pointer)++ = data_full[icol + 1 ];
302
+ }
303
+ data_full += ap.get_col_size ();
304
+ break ;
294
305
}
295
- data_full += ap.get_col_size ();
296
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297
- *(tmp_pointer[2 ])++ = data_full[icol];
298
- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
299
- }
300
- data_full += ap.get_col_size ();
301
306
}
302
307
}
303
308
}
304
309
hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
305
-
306
-
310
+ }
307
311
#else
308
- this ->DMRGint [is] = DM2D[0 ];
312
+ // TODO: implement the non-MPI version (this branch is currently empty)
313
+
309
314
#endif
310
- }
311
- // std::vector<double*> tmp_pointer(4, nullptr);
312
- // for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
313
- // auto& ap = this->DMRGint_full->get_atom_pair(iap);
314
- // int iat1 = ap.get_atom_i();
315
- // int iat2 = ap.get_atom_j();
316
- // for (int ir = 0; ir < ap.get_R_size(); ++ir) {
317
- // const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
318
- // for (int is = 0; is < 4; is++) {
319
- // tmp_pointer[is] = this->DMRGint[is]
320
- // ->find_matrix(iat1, iat2, r_index)
321
- // ->get_pointer();
322
- // }
323
- // double* data_full = ap.get_pointer(ir);
324
- // for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
325
- // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
326
- // *(tmp_pointer[0])++ = data_full[icol];
327
- // *(tmp_pointer[1])++ = data_full[icol + 1];
328
- // }
329
- // data_full += ap.get_col_size();
330
- // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
331
- // *(tmp_pointer[2])++ = data_full[icol];
332
- // *(tmp_pointer[3])++ = data_full[icol + 1];
333
- // }
334
- // data_full += ap.get_col_size();
335
- // }
336
- // }
337
- // }
338
- }
315
+ }
339
316
ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
340
317
}
0 commit comments