Skip to content

Commit 107024c

Browse files
authored
Solving linear equations to evolve the wave function in RT-TDDFT. (#5925)
* Add files via upload * Add files via upload * Add files via upload * Update input-main.md * Update solve_propagation.cpp
1 parent b774535 commit 107024c

File tree

7 files changed

+187
-12
lines changed

7 files changed

+187
-12
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3477,9 +3477,10 @@ These variables are used to control berry phase and wannier90 interface paramete
34773477
- **Type**: Integer
34783478
- **Description**:
34793479
method of propagator
3480-
- 0: Crank-Nicolson.
3480+
- 0: Crank-Nicolson, based on matrix inversion.
34813481
- 1: 4th Taylor expansions of exponential.
34823482
- 2: enforced time-reversal symmetry (ETRS).
3483+
- 3: Crank-Nicolson, based on solving linear equation.
34833484
- **Default**: 0
34843485

34853486
### td_vext

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ OBJS_LCAO=evolve_elec.o\
571571
td_velocity.o\
572572
td_current.o\
573573
snap_psibeta_half_tddft.o\
574+
solve_propagation.o\
574575
upsi.o\
575576
FORCE_STRESS.o\
576577
FORCE_gamma.o\

source/module_base/scalapack_connector.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ extern "C"
7474
const int *M, const int *N,
7575
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
7676
int *ipiv, int *info);
77+
78+
void pzgesv_(
79+
const int *n, const int *nrhs,
80+
const std::complex<double> *A, const int *ia, const int *ja, const int *desca,
81+
int *ipiv, std::complex<double>* B, const int* ib, const int* jb, const int*descb, const int *info
82+
);
7783

7884
void pdsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
7985
const int* n, double* A, const int* ia, const int* ja, const int*desca, double* B, const int* ib, const int* jb, const int*descb,
@@ -240,6 +246,15 @@ class ScalapackConnector
240246
pzgetri_(&n, A, &ia, &ja, desca, ipiv, work, lwork, iwork, liwork, info);
241247
}
242248

249+
static inline
250+
void gesv(
251+
const int n, const int nrhs,
252+
const std::complex<double> *A, const int ia, const int ja, const int *desca,
253+
int *ipiv, std::complex<double>* B, const int ib, const int jb, const int*descb, int *info)
254+
{
255+
pzgesv_(&n, &nrhs, A, &ia, &ja, desca, ipiv, B, &ib, &jb, descb, info);
256+
}
257+
243258
static inline
244259
void tranu(
245260
const int m, const int n,

source/module_hamilt_lcao/module_tddft/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ if(ENABLE_LCAO)
1313
td_velocity.cpp
1414
td_current.cpp
1515
snap_psibeta_half_tddft.cpp
16+
solve_propagation.cpp
1617
)
1718

1819
add_library(

source/module_hamilt_lcao/module_tddft/evolve_psi.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "norm_psi.h"
1414
#include "propagator.h"
1515
#include "upsi.h"
16+
#include "solve_propagation.h"
1617

1718
#include <complex>
1819

@@ -69,19 +70,30 @@ void evolve_psi(const int nband,
6970
}
7071

7172
// (2)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
72-
73-
/// @brief compute U_operator
74-
/// @input Stmp, Htmp, print_matrix
75-
/// @output U_operator
76-
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
77-
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
73+
if (propagator != 3)
74+
{
75+
/// @brief compute U_operator
76+
/// @input Stmp, Htmp, print_matrix
77+
/// @output U_operator
78+
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
79+
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
80+
}
7881

7982
// (3)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
80-
81-
/// @brief apply U_operator to the wave function of the previous step for new wave function
82-
/// @input U_operator, psi_k_laststep, print_matrix
83-
/// @output psi_k
84-
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
83+
if (propagator != 3)
84+
{
85+
/// @brief apply U_operator to the wave function of the previous step for new wave function
86+
/// @input U_operator, psi_k_laststep, print_matrix
87+
/// @output psi_k
88+
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
89+
}
90+
else
91+
{
92+
/// @brief solve the propagation equation
93+
/// @input Stmp, Htmp, psi_k_laststep
94+
/// @output psi_k
95+
solve_propagation(pv, nband, nlocal, PARAM.mdp.md_dt / ModuleBase::AU_to_FS, Stmp, Htmp, psi_k_laststep, psi_k);
96+
}
8597

8698
// (4)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
8799

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include "solve_propagation.h"
2+
3+
#include <iostream>
4+
5+
#include "module_base/lapack_connector.h"
6+
#include "module_base/scalapack_connector.h"
7+
8+
namespace module_tddft
9+
{
10+
#ifdef __MPI
11+
void solve_propagation(const Parallel_Orbitals* pv,
12+
const int nband,
13+
const int nlocal,
14+
const double dt,
15+
const std::complex<double>* Stmp,
16+
const std::complex<double>* Htmp,
17+
const std::complex<double>* psi_k_laststep,
18+
std::complex<double>* psi_k)
19+
{
20+
// (1) init A,B and copy Htmp to A & B
21+
std::complex<double>* operator_A = new std::complex<double>[pv->nloc];
22+
ModuleBase::GlobalFunc::ZEROS(operator_A, pv->nloc);
23+
BlasConnector::copy(pv->nloc, Htmp, 1, operator_A, 1);
24+
25+
std::complex<double>* operator_B = new std::complex<double>[pv->nloc];
26+
ModuleBase::GlobalFunc::ZEROS(operator_B, pv->nloc);
27+
BlasConnector::copy(pv->nloc, Htmp, 1, operator_B, 1);
28+
29+
// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
30+
// (2) compute operator_A & operator_B by GEADD
31+
// operator_A = Stmp + i*para * Htmp; beta2 = para = 0.25 * dt
32+
// operator_B = Stmp - i*para * Htmp; beta1 = - para = -0.25 * dt
33+
std::complex<double> alpha = {1.0, 0.0};
34+
std::complex<double> beta1 = {0.0, -0.25 * dt};
35+
std::complex<double> beta2 = {0.0, 0.25 * dt};
36+
37+
ScalapackConnector::geadd('N',
38+
nlocal,
39+
nlocal,
40+
alpha,
41+
Stmp,
42+
1,
43+
1,
44+
pv->desc,
45+
beta2,
46+
operator_A,
47+
1,
48+
1,
49+
pv->desc);
50+
ScalapackConnector::geadd('N',
51+
nlocal,
52+
nlocal,
53+
alpha,
54+
Stmp,
55+
1,
56+
1,
57+
pv->desc,
58+
beta1,
59+
operator_B,
60+
1,
61+
1,
62+
pv->desc);
63+
// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
64+
// (3) b = operator_B @ psi_k_laststep
65+
std::complex<double>* tmp_b = new std::complex<double>[pv->nloc_wfc];
66+
ScalapackConnector::gemm('N',
67+
'N',
68+
nlocal,
69+
nband,
70+
nlocal,
71+
1.0,
72+
operator_B,
73+
1,
74+
1,
75+
pv->desc,
76+
psi_k_laststep,
77+
1,
78+
1,
79+
pv->desc_wfc,
80+
0.0,
81+
tmp_b,
82+
1,
83+
1,
84+
pv->desc_wfc);
85+
//get ipiv
86+
int* ipiv = new int[pv->nloc];
87+
int info = 0;
88+
// (4) solve Ac=b
89+
ScalapackConnector::gesv(nlocal,
90+
nband,
91+
operator_A,
92+
1,
93+
1,
94+
pv->desc,
95+
ipiv,
96+
tmp_b,
97+
1,
98+
1,
99+
pv->desc_wfc,
100+
&info);
101+
102+
//copy solution to psi_k
103+
BlasConnector::copy(pv->nloc_wfc, tmp_b, 1, psi_k, 1);
104+
105+
delete []tmp_b;
106+
delete []ipiv;
107+
delete []operator_A;
108+
delete []operator_B;
109+
}
110+
#endif // __MPI
111+
} // namespace module_tddft
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef TD_SOLVE_PROPAGATION_H
2+
#define TD_SOLVE_PROPAGATION_H
3+
4+
#include "module_basis/module_ao/parallel_orbitals.h"
5+
#include <complex>
6+
7+
namespace module_tddft
8+
{
9+
#ifdef __MPI
10+
/**
11+
* @brief solve propagation equation A@c(t+dt) = B@c(t)
12+
*
13+
* @param[in] pv information of parallel
14+
* @param[in] nband number of bands
15+
* @param[in] nlocal number of orbitals
16+
* @param[in] dt time interval
17+
* @param[in] Stmp overlap matrix S(t+dt/2)
18+
* @param[in] Htmp H(t+dt/2)
19+
* @param[in] psi_k_laststep psi of last step
20+
* @param[out] psi_k psi of this step
21+
*/
22+
void solve_propagation(const Parallel_Orbitals* pv,
23+
const int nband,
24+
const int nlocal,
25+
const double dt,
26+
const std::complex<double>* Stmp,
27+
const std::complex<double>* Htmp,
28+
const std::complex<double>* psi_k_laststep,
29+
std::complex<double>* psi_k);
30+
31+
#endif
32+
} // namespace module_tddft
33+
34+
#endif // TD_SOLVE_H

0 commit comments

Comments
 (0)