Skip to content

Commit a95eb5b

Browse files
authored
Refactor: move cal_edm_tddft to module_dm (#5485)
* Refactor: move cal_edm_tddft to module_dm * update head file * add lapack_connector.h in cal_edm_tddft.cpp
1 parent 8e8cad2 commit a95eb5b

File tree

7 files changed

+70
-90
lines changed

7 files changed

+70
-90
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ OBJS_ELECSTAT_LCAO=elecstate_lcao.o\
233233
density_matrix.o\
234234
density_matrix_io.o\
235235
cal_dm_psi.o\
236+
cal_edm_tddft.o\
236237

237238
OBJS_ESOLVER=esolver.o\
238239
esolver_ks.o\
@@ -259,7 +260,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
259260
lcao_others.o\
260261
lcao_init_after_vc.o\
261262
lcao_fun.o\
262-
cal_edm_tddft.o\
263263

264264
OBJS_GINT=gint.o\
265265
gint_gamma_env.o\

source/module_elecstate/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ if(ENABLE_LCAO)
4242
module_dm/density_matrix.cpp
4343
module_dm/density_matrix_io.cpp
4444
module_dm/cal_dm_psi.cpp
45+
module_dm/cal_edm_tddft.cpp
4546
)
4647
endif()
4748

source/module_esolver/cal_edm_tddft.cpp renamed to source/module_elecstate/module_dm/cal_edm_tddft.cpp

Lines changed: 50 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,46 @@
1-
#include "esolver_ks_lcao_tddft.h"
2-
3-
#include "module_io/cal_r_overlap_R.h"
4-
#include "module_io/dipole_io.h"
5-
#include "module_io/td_current_io.h"
6-
#include "module_io/write_HS.h"
7-
#include "module_io/write_HS_R.h"
8-
#include "module_io/write_wfc_nao.h"
9-
10-
//--------------temporary----------------------------
11-
#include "module_base/blas_connector.h"
12-
#include "module_base/global_function.h"
13-
#include "module_base/scalapack_connector.h"
1+
#include "cal_edm_tddft.h"
2+
143
#include "module_base/lapack_connector.h"
15-
#include "module_elecstate/module_charge/symmetry_rho.h"
16-
#include "module_elecstate/occupy.h"
17-
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
18-
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
19-
#include "module_hamilt_lcao/module_tddft/td_velocity.h"
20-
#include "module_hamilt_pw/hamilt_pwdft/global.h"
21-
#include "module_io/print_info.h"
22-
23-
//-----HSolver ElecState Hamilt--------
24-
#include "module_elecstate/elecstate_lcao.h"
25-
#include "module_elecstate/elecstate_lcao_tddft.h"
26-
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
27-
#include "module_hsolver/hsolver_lcao.h"
28-
#include "module_parameter/parameter.h"
29-
#include "module_psi/psi.h"
30-
31-
//-----force& stress-------------------
32-
#include "module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h"
33-
34-
//---------------------------------------------------
35-
36-
namespace ModuleESolver
4+
#include "module_base/scalapack_connector.h"
5+
namespace elecstate
376
{
387

39-
// use the original formula (Hamiltonian matrix) to calculate energy density
40-
// matrix
41-
void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
8+
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
9+
void cal_edm_tddft(Parallel_Orbitals& pv,
10+
elecstate::ElecState* pelec,
11+
K_Vectors& kv,
12+
hamilt::Hamilt<std::complex<double>>* p_hamilt)
4213
{
4314
// mohan add 2024-03-27
4415
const int nlocal = PARAM.globalv.nlocal;
4516
assert(nlocal >= 0);
4617

47-
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
48-
->get_DM()
49-
->EDMK.resize(kv.get_nks());
50-
for (int ik = 0; ik < kv.get_nks(); ++ik) {
18+
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(pelec);
5119

52-
p_hamilt->updateHk(ik);
20+
_pelec->get_DM()->EDMK.resize(kv.get_nks());
5321

54-
std::complex<double>* tmp_dmk
55-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_DMK_pointer(ik);
56-
57-
ModuleBase::ComplexMatrix& tmp_edmk
58-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->EDMK[ik];
59-
60-
const Parallel_Orbitals* tmp_pv
61-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_paraV_pointer();
22+
for (int ik = 0; ik < kv.get_nks(); ++ik)
23+
{
24+
p_hamilt->updateHk(ik);
25+
std::complex<double>* tmp_dmk = _pelec->get_DM()->get_DMK_pointer(ik);
26+
ModuleBase::ComplexMatrix& tmp_edmk = _pelec->get_DM()->EDMK[ik];
6227

6328
#ifdef __MPI
6429

6530
// mohan add 2024-03-27
6631
//! be careful, the type of nloc is 'long'
6732
//! whether the long type is safe, needs more discussion
68-
const long nloc = this->pv.nloc;
69-
const int ncol = this->pv.ncol;
70-
const int nrow = this->pv.nrow;
33+
const long nloc = pv.nloc;
34+
const int ncol = pv.ncol;
35+
const int nrow = pv.nrow;
7136

7237
tmp_edmk.create(ncol, nrow);
73-
complex<double>* Htmp = new complex<double>[nloc];
74-
complex<double>* Sinv = new complex<double>[nloc];
75-
complex<double>* tmp1 = new complex<double>[nloc];
76-
complex<double>* tmp2 = new complex<double>[nloc];
77-
complex<double>* tmp3 = new complex<double>[nloc];
78-
complex<double>* tmp4 = new complex<double>[nloc];
38+
std::complex<double>* Htmp = new std::complex<double>[nloc];
39+
std::complex<double>* Sinv = new std::complex<double>[nloc];
40+
std::complex<double>* tmp1 = new std::complex<double>[nloc];
41+
std::complex<double>* tmp2 = new std::complex<double>[nloc];
42+
std::complex<double>* tmp3 = new std::complex<double>[nloc];
43+
std::complex<double>* tmp4 = new std::complex<double>[nloc];
7944

8045
ModuleBase::GlobalFunc::ZEROS(Htmp, nloc);
8146
ModuleBase::GlobalFunc::ZEROS(Sinv, nloc);
@@ -86,8 +51,8 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
8651

8752
const int inc = 1;
8853

89-
hamilt::MatrixBlock<complex<double>> h_mat;
90-
hamilt::MatrixBlock<complex<double>> s_mat;
54+
hamilt::MatrixBlock<std::complex<double>> h_mat;
55+
hamilt::MatrixBlock<std::complex<double>> s_mat;
9156

9257
p_hamilt->matrix(h_mat, s_mat);
9358
zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc);
@@ -97,7 +62,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
9762
int info = 0;
9863
const int one_int = 1;
9964

100-
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, this->pv.desc, ipiv.data(), &info);
65+
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info);
10166

10267
int lwork = -1;
10368
int liwork = -1;
@@ -112,7 +77,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
11277
Sinv,
11378
&one_int,
11479
&one_int,
115-
this->pv.desc,
80+
pv.desc,
11681
ipiv.data(),
11782
work.data(),
11883
&lwork,
@@ -129,7 +94,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
12994
Sinv,
13095
&one_int,
13196
&one_int,
132-
this->pv.desc,
97+
pv.desc,
13398
ipiv.data(),
13499
work.data(),
135100
&lwork,
@@ -139,9 +104,9 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
139104

140105
const char N_char = 'N';
141106
const char T_char = 'T';
142-
const complex<double> one_float = {1.0, 0.0};
143-
const complex<double> zero_float = {0.0, 0.0};
144-
const complex<double> half_float = {0.5, 0.0};
107+
const std::complex<double> one_float = {1.0, 0.0};
108+
const std::complex<double> zero_float = {0.0, 0.0};
109+
const std::complex<double> half_float = {0.5, 0.0};
145110

146111
pzgemm_(&N_char,
147112
&N_char,
@@ -152,16 +117,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
152117
Htmp,
153118
&one_int,
154119
&one_int,
155-
this->pv.desc,
120+
pv.desc,
156121
Sinv,
157122
&one_int,
158123
&one_int,
159-
this->pv.desc,
124+
pv.desc,
160125
&zero_float,
161126
tmp1,
162127
&one_int,
163128
&one_int,
164-
this->pv.desc);
129+
pv.desc);
165130

166131
pzgemm_(&T_char,
167132
&N_char,
@@ -172,16 +137,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
172137
tmp1,
173138
&one_int,
174139
&one_int,
175-
this->pv.desc,
140+
pv.desc,
176141
tmp_dmk,
177142
&one_int,
178143
&one_int,
179-
this->pv.desc,
144+
pv.desc,
180145
&zero_float,
181146
tmp2,
182147
&one_int,
183148
&one_int,
184-
this->pv.desc);
149+
pv.desc);
185150

186151
pzgemm_(&N_char,
187152
&N_char,
@@ -192,16 +157,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
192157
Sinv,
193158
&one_int,
194159
&one_int,
195-
this->pv.desc,
160+
pv.desc,
196161
Htmp,
197162
&one_int,
198163
&one_int,
199-
this->pv.desc,
164+
pv.desc,
200165
&zero_float,
201166
tmp3,
202167
&one_int,
203168
&one_int,
204-
this->pv.desc);
169+
pv.desc);
205170

206171
pzgemm_(&N_char,
207172
&T_char,
@@ -212,16 +177,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
212177
tmp_dmk,
213178
&one_int,
214179
&one_int,
215-
this->pv.desc,
180+
pv.desc,
216181
tmp3,
217182
&one_int,
218183
&one_int,
219-
this->pv.desc,
184+
pv.desc,
220185
&zero_float,
221186
tmp4,
222187
&one_int,
223188
&one_int,
224-
this->pv.desc);
189+
pv.desc);
225190

226191
pzgeadd_(&N_char,
227192
&nlocal,
@@ -230,12 +195,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
230195
tmp2,
231196
&one_int,
232197
&one_int,
233-
this->pv.desc,
198+
pv.desc,
234199
&half_float,
235200
tmp4,
236201
&one_int,
237202
&one_int,
238-
this->pv.desc);
203+
pv.desc);
239204

240205
zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc);
241206

@@ -247,12 +212,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
247212
delete[] tmp4;
248213
#else
249214
// for serial version
250-
tmp_edmk.create(this->pv.ncol, this->pv.nrow);
215+
tmp_edmk.create(pv.ncol, pv.nrow);
251216
ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
252217
ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);
253218

254-
hamilt::MatrixBlock<complex<double>> h_mat;
255-
hamilt::MatrixBlock<complex<double>> s_mat;
219+
hamilt::MatrixBlock<std::complex<double>> h_mat;
220+
hamilt::MatrixBlock<std::complex<double>> s_mat;
256221

257222
p_hamilt->matrix(h_mat, s_mat);
258223
// cout<<"hmat "<<h_mat.p[0]<<endl;
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#ifndef CAL_EDM_TDDFT_H
2+
#define CAL_EDM_TDDFT_H
3+
4+
#include "module_basis/module_ao/parallel_orbitals.h"
5+
#include "module_cell/klist.h"
6+
#include "module_elecstate/elecstate_lcao.h"
7+
#include "module_hamilt_general/hamilt.h"
8+
9+
namespace elecstate
10+
{
11+
void cal_edm_tddft(Parallel_Orbitals& pv,
12+
elecstate::ElecState* pelec,
13+
K_Vectors& kv,
14+
hamilt::Hamilt<std::complex<double>>* p_hamilt);
15+
} // namespace elecstate
16+
#endif // CAL_EDM_TDDFT_H

source/module_esolver/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ if(ENABLE_LCAO)
2626
lcao_others.cpp
2727
lcao_init_after_vc.cpp
2828
lcao_fun.cpp
29-
cal_edm_tddft.cpp
3029
)
3130
endif()
3231

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "module_base/lapack_connector.h"
1414
#include "module_base/scalapack_connector.h"
1515
#include "module_elecstate/module_charge/symmetry_rho.h"
16+
#include "module_elecstate/module_dm/cal_edm_tddft.h"
1617
#include "module_elecstate/occupy.h"
1718
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
1819
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
@@ -358,7 +359,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
358359
// calculate energy density matrix for tddft
359360
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
360361
{
361-
this->cal_edm_tddft();
362+
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
362363
}
363364
}
364365

source/module_esolver/esolver_ks_lcao_tddft.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, doubl
3737
virtual void iter_finish(const int istep, int& iter) override;
3838

3939
virtual void after_scf(const int istep) override;
40-
41-
void cal_edm_tddft();
4240
};
4341

4442
} // namespace ModuleESolver

0 commit comments

Comments
 (0)