@@ -171,6 +171,10 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
171
171
this ->hRGint_tmp [is] = new hamilt::HContainer<double >(ucell_in.nat );
172
172
}
173
173
#ifdef __MPI
174
+ if (this ->DM2D_tmp != nullptr ) {
175
+ delete this ->DM2D_tmp ;
176
+ }
177
+ this ->DM2D_tmp = new hamilt::HContainer<double >(ucell_in.nat );
174
178
if (this ->DMRGint_full != nullptr ) {
175
179
delete this ->DMRGint_full ;
176
180
}
@@ -201,9 +205,14 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
201
205
this ->hRGintCd ->allocate (nullptr , true );
202
206
ModuleBase::Memory::record (" Gint::hRGintCd" ,
203
207
this ->hRGintCd ->get_memory_size ());
208
+ this ->DM2D_tmp ->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in, npol);
209
+ this ->DM2D_tmp ->allocate (nullptr , true );
210
+ ModuleBase::Memory::record (" Gint::DM2D_tmp" ,
211
+ this ->DM2D_tmp ->get_memory_size ());
204
212
for (int is = 0 ; is < nspin; is++) {
205
213
this ->hRGint_tmp [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
206
214
this ->DMRGint [is]->insert_ijrs (this ->gridt ->get_ijr_info (), ucell_in);
215
+
207
216
this ->hRGint_tmp [is]->allocate (nullptr , true );
208
217
this ->DMRGint [is]->allocate (nullptr , true );
209
218
}
@@ -218,6 +227,8 @@ void Gint::initialize_pvpR(const UnitCell& ucell_in, const Grid_Driver* gd, cons
218
227
this ->DMRGint_full ->get_memory_size ());
219
228
#endif
220
229
}
230
+ // std::cout<<" DMRGint " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
231
+ // std::cout<<" DMRGint_full " << DMRGint[0]->get_atom_pair(0,1).get_row_size() << std::endl;
221
232
}
222
233
223
234
void Gint::reset_DMRGint (const int & nspin)
@@ -262,38 +273,68 @@ void Gint::transfer_DM2DtoGrid(std::vector<hamilt::HContainer<double>*> DM2D) {
262
273
}
263
274
} else // NSPIN=4 case
264
275
{
276
+
277
+ for (int is = 0 ; is < 4 ; is++) {
265
278
#ifdef __MPI
266
- hamilt::transferParallels2Serials (*DM2D[0 ], this ->DMRGint_full );
267
- #else
268
- this ->DMRGint_full = DM2D[0 ];
269
- #endif
270
- std::vector<double *> tmp_pointer (4 , nullptr );
271
- for (int iap = 0 ; iap < this ->DMRGint_full ->size_atom_pairs (); ++iap) {
272
- auto & ap = this ->DMRGint_full ->get_atom_pair (iap);
273
- int iat1 = ap.get_atom_i ();
274
- int iat2 = ap.get_atom_j ();
275
- for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
276
- const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
277
- for (int is = 0 ; is < 4 ; is++) {
278
- tmp_pointer[is] = this ->DMRGint [is]
279
- ->find_matrix (iat1, iat2, r_index)
280
- ->get_pointer ();
281
- }
282
- double * data_full = ap.get_pointer (ir);
283
- for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
284
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
285
- *(tmp_pointer[0 ])++ = data_full[icol];
286
- *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
279
+ std::vector<double *> tmp_pointer (4 , nullptr );
280
+ for (int iap = 0 ; iap < DM2D[0 ]->size_atom_pairs (); ++iap) {
281
+ auto & ap = DM2D[0 ]->get_atom_pair (iap);
282
+ int iat1 = ap.get_atom_i ();
283
+ int iat2 = ap.get_atom_j ();
284
+ for (int ir = 0 ; ir < ap.get_R_size (); ++ir) {
285
+ const ModuleBase::Vector3<int > r_index = ap.get_R_index (ir);
286
+ for (int is = 0 ; is < 4 ; is++) {
287
+ tmp_pointer[is] = this -> DM2D_tmp->find_matrix (iat1, iat2, r_index)->get_pointer ();
287
288
}
288
- data_full += ap.get_col_size ();
289
- for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
290
- *(tmp_pointer[2 ])++ = data_full[icol];
291
- *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
289
+ double * data_full = ap.get_pointer (ir);
290
+ for (int irow = 0 ; irow < ap.get_row_size (); irow += 2 ) {
291
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
292
+ *(tmp_pointer[0 ])++ = data_full[icol];
293
+ *(tmp_pointer[1 ])++ = data_full[icol + 1 ];
294
+ }
295
+ data_full += ap.get_col_size ();
296
+ for (int icol = 0 ; icol < ap.get_col_size (); icol += 2 ) {
297
+ *(tmp_pointer[2 ])++ = data_full[icol];
298
+ *(tmp_pointer[3 ])++ = data_full[icol + 1 ];
299
+ }
300
+ data_full += ap.get_col_size ();
292
301
}
293
- data_full += ap.get_col_size ();
294
302
}
295
303
}
304
+ hamilt::transferParallels2Serials ( *(this ->DM2D_tmp ), this ->DMRGint [is]);
305
+
306
+
307
+ #else
308
+ this ->DMRGint [is] = DM2D[0 ];
309
+ #endif
296
310
}
297
- }
311
+ // std::vector<double*> tmp_pointer(4, nullptr);
312
+ // for (int iap = 0; iap < this->DMRGint_full->size_atom_pairs(); ++iap) {
313
+ // auto& ap = this->DMRGint_full->get_atom_pair(iap);
314
+ // int iat1 = ap.get_atom_i();
315
+ // int iat2 = ap.get_atom_j();
316
+ // for (int ir = 0; ir < ap.get_R_size(); ++ir) {
317
+ // const ModuleBase::Vector3<int> r_index = ap.get_R_index(ir);
318
+ // for (int is = 0; is < 4; is++) {
319
+ // tmp_pointer[is] = this->DMRGint[is]
320
+ // ->find_matrix(iat1, iat2, r_index)
321
+ // ->get_pointer();
322
+ // }
323
+ // double* data_full = ap.get_pointer(ir);
324
+ // for (int irow = 0; irow < ap.get_row_size(); irow += 2) {
325
+ // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
326
+ // *(tmp_pointer[0])++ = data_full[icol];
327
+ // *(tmp_pointer[1])++ = data_full[icol + 1];
328
+ // }
329
+ // data_full += ap.get_col_size();
330
+ // for (int icol = 0; icol < ap.get_col_size(); icol += 2) {
331
+ // *(tmp_pointer[2])++ = data_full[icol];
332
+ // *(tmp_pointer[3])++ = data_full[icol + 1];
333
+ // }
334
+ // data_full += ap.get_col_size();
335
+ // }
336
+ // }
337
+ // }
338
+ }
298
339
ModuleBase::timer::tick (" Gint" , " transfer_DMR" );
299
340
}
0 commit comments