|
1 | 1 | #include "esolver_ks_lcao_tddft.h" |
2 | 2 |
|
3 | 3 | //----------------IO----------------- |
4 | | -#include "source_io/module_dipole/dipole_io.h" |
5 | 4 | #include "source_io/module_ctrl/ctrl_output_td.h" |
6 | 5 | #include "source_io/module_current/td_current_io.h" |
| 6 | +#include "source_io/module_dipole/dipole_io.h" |
7 | 7 | #include "source_io/module_output/output_log.h" |
8 | 8 | #include "source_io/module_wf/read_wfc_nao.h" |
9 | 9 | //------LCAO HSolver ElecState------- |
@@ -425,32 +425,51 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::store_h_s_psi(UnitCell& ucell, |
425 | 425 | // Store H and S matrices to Hk_laststep and Sk_laststep |
426 | 426 | if (use_tensor && use_lapack) |
427 | 427 | { |
428 | | - // Gather H and S matrices to root process |
429 | 428 | #ifdef __MPI |
430 | 429 | int myid = 0; |
431 | 430 | int num_procs = 1; |
432 | 431 | MPI_Comm_rank(MPI_COMM_WORLD, &myid); |
433 | 432 | MPI_Comm_size(MPI_COMM_WORLD, &num_procs); |
434 | 433 |
|
435 | | - // Global matrix structure |
| 434 | + std::complex<double>* h_ptr = nullptr; |
| 435 | + std::complex<double>* s_ptr = nullptr; |
| 436 | + |
| 437 | + // Define containers for gathered data (only needed for multi-process) |
436 | 438 | module_rt::Matrix_g<std::complex<double>> h_mat_g; |
437 | 439 | module_rt::Matrix_g<std::complex<double>> s_mat_g; |
438 | 440 |
|
439 | | - // Collect H matrix |
440 | | - module_rt::gatherMatrix(myid, 0, h_mat, h_mat_g); |
441 | | - BlasConnector::copy(len_HS_ik, |
442 | | - h_mat_g.p.get(), |
443 | | - 1, |
444 | | - this->Hk_laststep.template data<std::complex<double>>() + ik * len_HS_ik, |
445 | | - 1); |
| 441 | + if (num_procs == 1) |
| 442 | + { |
| 443 | + // Single process: directly point to local data without gather |
| 444 | + h_ptr = h_mat.p; |
| 445 | + s_ptr = s_mat.p; |
| 446 | + } |
| 447 | + else |
| 448 | + { |
| 449 | + // Multiple processes: gather data to the root process (myid == 0) and point to the gathered data |
| 450 | + module_rt::gatherMatrix(myid, 0, h_mat, h_mat_g); |
| 451 | + module_rt::gatherMatrix(myid, 0, s_mat, s_mat_g); |
| 452 | + if (myid == 0) |
| 453 | + { |
| 454 | + h_ptr = h_mat_g.p.get(); |
| 455 | + s_ptr = s_mat_g.p.get(); |
| 456 | + } |
| 457 | + } |
446 | 458 |
|
447 | | - // Collect S matrix |
448 | | - module_rt::gatherMatrix(myid, 0, s_mat, s_mat_g); |
449 | | - BlasConnector::copy(len_HS_ik, |
450 | | - s_mat_g.p.get(), |
451 | | - 1, |
452 | | - this->Sk_laststep.template data<std::complex<double>>() + ik * len_HS_ik, |
453 | | - 1); |
| 459 | + // Only the root process (myid == 0) performs the copy |
| 460 | + if (myid == 0 && h_ptr != nullptr && s_ptr != nullptr) |
| 461 | + { |
| 462 | + BlasConnector::copy(len_HS_ik, |
| 463 | + h_ptr, |
| 464 | + 1, |
| 465 | + this->Hk_laststep.template data<std::complex<double>>() + ik * len_HS_ik, |
| 466 | + 1); |
| 467 | + BlasConnector::copy(len_HS_ik, |
| 468 | + s_ptr, |
| 469 | + 1, |
| 470 | + this->Sk_laststep.template data<std::complex<double>>() + ik * len_HS_ik, |
| 471 | + 1); |
| 472 | + } |
454 | 473 | #endif |
455 | 474 | } |
456 | 475 | else |
|
0 commit comments