Skip to content

Commit 7035e85

Browse files
authored
Refactor: Optimize memory management in RT-TDDFT to enable large-scale GPU calculations (#6995)
* Optimize memory management of RT-TDDFT to enable larger scale calculation * Fixed a bug where calculation results were incorrect due to not allocating H matrix on device
1 parent d999af7 commit 7035e85

9 files changed

Lines changed: 421 additions & 688 deletions

File tree

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include "esolver_ks_lcao_tddft.h"
22

33
//----------------IO-----------------
4-
#include "source_io/module_dipole/dipole_io.h"
54
#include "source_io/module_ctrl/ctrl_output_td.h"
65
#include "source_io/module_current/td_current_io.h"
6+
#include "source_io/module_dipole/dipole_io.h"
77
#include "source_io/module_output/output_log.h"
88
#include "source_io/module_wf/read_wfc_nao.h"
99
//------LCAO HSolver ElecState-------
@@ -425,32 +425,51 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::store_h_s_psi(UnitCell& ucell,
425425
// Store H and S matrices to Hk_laststep and Sk_laststep
426426
if (use_tensor && use_lapack)
427427
{
428-
// Gather H and S matrices to root process
429428
#ifdef __MPI
430429
int myid = 0;
431430
int num_procs = 1;
432431
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
433432
MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
434433

435-
// Global matrix structure
434+
std::complex<double>* h_ptr = nullptr;
435+
std::complex<double>* s_ptr = nullptr;
436+
437+
// Define containers for gathered data (only needed for multi-process)
436438
module_rt::Matrix_g<std::complex<double>> h_mat_g;
437439
module_rt::Matrix_g<std::complex<double>> s_mat_g;
438440

439-
// Collect H matrix
440-
module_rt::gatherMatrix(myid, 0, h_mat, h_mat_g);
441-
BlasConnector::copy(len_HS_ik,
442-
h_mat_g.p.get(),
443-
1,
444-
this->Hk_laststep.template data<std::complex<double>>() + ik * len_HS_ik,
445-
1);
441+
if (num_procs == 1)
442+
{
443+
// Single process: directly point to local data without gather
444+
h_ptr = h_mat.p;
445+
s_ptr = s_mat.p;
446+
}
447+
else
448+
{
449+
// Multiple processes: gather data to the root process (myid == 0) and point to the gathered data
450+
module_rt::gatherMatrix(myid, 0, h_mat, h_mat_g);
451+
module_rt::gatherMatrix(myid, 0, s_mat, s_mat_g);
452+
if (myid == 0)
453+
{
454+
h_ptr = h_mat_g.p.get();
455+
s_ptr = s_mat_g.p.get();
456+
}
457+
}
446458

447-
// Collect S matrix
448-
module_rt::gatherMatrix(myid, 0, s_mat, s_mat_g);
449-
BlasConnector::copy(len_HS_ik,
450-
s_mat_g.p.get(),
451-
1,
452-
this->Sk_laststep.template data<std::complex<double>>() + ik * len_HS_ik,
453-
1);
459+
// Only the root process (myid == 0) performs the copy
460+
if (myid == 0 && h_ptr != nullptr && s_ptr != nullptr)
461+
{
462+
BlasConnector::copy(len_HS_ik,
463+
h_ptr,
464+
1,
465+
this->Hk_laststep.template data<std::complex<double>>() + ik * len_HS_ik,
466+
1);
467+
BlasConnector::copy(len_HS_ik,
468+
s_ptr,
469+
1,
470+
this->Sk_laststep.template data<std::complex<double>>() + ik * len_HS_ik,
471+
1);
472+
}
454473
#endif
455474
}
456475
else

0 commit comments

Comments
 (0)