Skip to content

Commit 55082f5

Browse files
committed
Refactor: move cal_edm_tddft to module_dm
1 parent dcff74d commit 55082f5

File tree

7 files changed

+69
-90
lines changed

7 files changed

+69
-90
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ OBJS_ELECSTAT_LCAO=elecstate_lcao.o\
233233
density_matrix.o\
234234
density_matrix_io.o\
235235
cal_dm_psi.o\
236+
cal_edm_tddft.o\
236237

237238
OBJS_ESOLVER=esolver.o\
238239
esolver_ks.o\
@@ -259,7 +260,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
259260
lcao_others.o\
260261
lcao_init_after_vc.o\
261262
lcao_fun.o\
262-
cal_edm_tddft.o\
263263

264264
OBJS_GINT=gint.o\
265265
gint_gamma_env.o\

source/module_elecstate/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ if(ENABLE_LCAO)
4242
module_dm/density_matrix.cpp
4343
module_dm/density_matrix_io.cpp
4444
module_dm/cal_dm_psi.cpp
45+
module_dm/cal_edm_tddft.cpp
4546
)
4647
endif()
4748

source/module_esolver/cal_edm_tddft.cpp renamed to source/module_elecstate/module_dm/cal_edm_tddft.cpp

Lines changed: 49 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,45 @@
1-
#include "esolver_ks_lcao_tddft.h"
2-
3-
#include "module_io/cal_r_overlap_R.h"
4-
#include "module_io/dipole_io.h"
5-
#include "module_io/td_current_io.h"
6-
#include "module_io/write_HS.h"
7-
#include "module_io/write_HS_R.h"
8-
#include "module_io/write_wfc_nao.h"
9-
10-
//--------------temporary----------------------------
11-
#include "module_base/blas_connector.h"
12-
#include "module_base/global_function.h"
1+
#include "cal_edm_tddft.h"
2+
133
#include "module_base/scalapack_connector.h"
14-
#include "module_base/lapack_connector.h"
15-
#include "module_elecstate/module_charge/symmetry_rho.h"
16-
#include "module_elecstate/occupy.h"
17-
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
18-
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
19-
#include "module_hamilt_lcao/module_tddft/td_velocity.h"
20-
#include "module_hamilt_pw/hamilt_pwdft/global.h"
21-
#include "module_io/print_info.h"
22-
23-
//-----HSolver ElecState Hamilt--------
24-
#include "module_elecstate/elecstate_lcao.h"
25-
#include "module_elecstate/elecstate_lcao_tddft.h"
26-
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
27-
#include "module_hsolver/hsolver_lcao.h"
28-
#include "module_parameter/parameter.h"
29-
#include "module_psi/psi.h"
30-
31-
//-----force& stress-------------------
32-
#include "module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h"
33-
34-
//---------------------------------------------------
35-
36-
namespace ModuleESolver
4+
namespace elecstate
375
{
386

39-
// use the original formula (Hamiltonian matrix) to calculate energy density
40-
// matrix
41-
void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
7+
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
8+
void cal_edm_tddft(Parallel_Orbitals& pv,
9+
elecstate::ElecState* pelec,
10+
K_Vectors& kv,
11+
hamilt::Hamilt<std::complex<double>>* p_hamilt)
4212
{
4313
// mohan add 2024-03-27
4414
const int nlocal = PARAM.globalv.nlocal;
4515
assert(nlocal >= 0);
4616

47-
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
48-
->get_DM()
49-
->EDMK.resize(kv.get_nks());
50-
for (int ik = 0; ik < kv.get_nks(); ++ik) {
51-
52-
p_hamilt->updateHk(ik);
17+
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(pelec);
5318

54-
std::complex<double>* tmp_dmk
55-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_DMK_pointer(ik);
19+
_pelec->get_DM()->EDMK.resize(kv.get_nks());
5620

57-
ModuleBase::ComplexMatrix& tmp_edmk
58-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->EDMK[ik];
59-
60-
const Parallel_Orbitals* tmp_pv
61-
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_paraV_pointer();
21+
for (int ik = 0; ik < kv.get_nks(); ++ik)
22+
{
23+
p_hamilt->updateHk(ik);
24+
std::complex<double>* tmp_dmk = _pelec->get_DM()->get_DMK_pointer(ik);
25+
ModuleBase::ComplexMatrix& tmp_edmk = _pelec->get_DM()->EDMK[ik];
6226

6327
#ifdef __MPI
6428

6529
// mohan add 2024-03-27
6630
//! be careful, the type of nloc is 'long'
6731
//! whether the long type is safe, needs more discussion
68-
const long nloc = this->pv.nloc;
69-
const int ncol = this->pv.ncol;
70-
const int nrow = this->pv.nrow;
32+
const long nloc = pv.nloc;
33+
const int ncol = pv.ncol;
34+
const int nrow = pv.nrow;
7135

7236
tmp_edmk.create(ncol, nrow);
73-
complex<double>* Htmp = new complex<double>[nloc];
74-
complex<double>* Sinv = new complex<double>[nloc];
75-
complex<double>* tmp1 = new complex<double>[nloc];
76-
complex<double>* tmp2 = new complex<double>[nloc];
77-
complex<double>* tmp3 = new complex<double>[nloc];
78-
complex<double>* tmp4 = new complex<double>[nloc];
37+
std::complex<double>* Htmp = new std::complex<double>[nloc];
38+
std::complex<double>* Sinv = new std::complex<double>[nloc];
39+
std::complex<double>* tmp1 = new std::complex<double>[nloc];
40+
std::complex<double>* tmp2 = new std::complex<double>[nloc];
41+
std::complex<double>* tmp3 = new std::complex<double>[nloc];
42+
std::complex<double>* tmp4 = new std::complex<double>[nloc];
7943

8044
ModuleBase::GlobalFunc::ZEROS(Htmp, nloc);
8145
ModuleBase::GlobalFunc::ZEROS(Sinv, nloc);
@@ -86,8 +50,8 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
8650

8751
const int inc = 1;
8852

89-
hamilt::MatrixBlock<complex<double>> h_mat;
90-
hamilt::MatrixBlock<complex<double>> s_mat;
53+
hamilt::MatrixBlock<std::complex<double>> h_mat;
54+
hamilt::MatrixBlock<std::complex<double>> s_mat;
9155

9256
p_hamilt->matrix(h_mat, s_mat);
9357
zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc);
@@ -97,7 +61,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
9761
int info = 0;
9862
const int one_int = 1;
9963

100-
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, this->pv.desc, ipiv.data(), &info);
64+
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info);
10165

10266
int lwork = -1;
10367
int liwork = -1;
@@ -112,7 +76,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
11276
Sinv,
11377
&one_int,
11478
&one_int,
115-
this->pv.desc,
79+
pv.desc,
11680
ipiv.data(),
11781
work.data(),
11882
&lwork,
@@ -129,7 +93,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
12993
Sinv,
13094
&one_int,
13195
&one_int,
132-
this->pv.desc,
96+
pv.desc,
13397
ipiv.data(),
13498
work.data(),
13599
&lwork,
@@ -139,9 +103,9 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
139103

140104
const char N_char = 'N';
141105
const char T_char = 'T';
142-
const complex<double> one_float = {1.0, 0.0};
143-
const complex<double> zero_float = {0.0, 0.0};
144-
const complex<double> half_float = {0.5, 0.0};
106+
const std::complex<double> one_float = {1.0, 0.0};
107+
const std::complex<double> zero_float = {0.0, 0.0};
108+
const std::complex<double> half_float = {0.5, 0.0};
145109

146110
pzgemm_(&N_char,
147111
&N_char,
@@ -152,16 +116,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
152116
Htmp,
153117
&one_int,
154118
&one_int,
155-
this->pv.desc,
119+
pv.desc,
156120
Sinv,
157121
&one_int,
158122
&one_int,
159-
this->pv.desc,
123+
pv.desc,
160124
&zero_float,
161125
tmp1,
162126
&one_int,
163127
&one_int,
164-
this->pv.desc);
128+
pv.desc);
165129

166130
pzgemm_(&T_char,
167131
&N_char,
@@ -172,16 +136,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
172136
tmp1,
173137
&one_int,
174138
&one_int,
175-
this->pv.desc,
139+
pv.desc,
176140
tmp_dmk,
177141
&one_int,
178142
&one_int,
179-
this->pv.desc,
143+
pv.desc,
180144
&zero_float,
181145
tmp2,
182146
&one_int,
183147
&one_int,
184-
this->pv.desc);
148+
pv.desc);
185149

186150
pzgemm_(&N_char,
187151
&N_char,
@@ -192,16 +156,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
192156
Sinv,
193157
&one_int,
194158
&one_int,
195-
this->pv.desc,
159+
pv.desc,
196160
Htmp,
197161
&one_int,
198162
&one_int,
199-
this->pv.desc,
163+
pv.desc,
200164
&zero_float,
201165
tmp3,
202166
&one_int,
203167
&one_int,
204-
this->pv.desc);
168+
pv.desc);
205169

206170
pzgemm_(&N_char,
207171
&T_char,
@@ -212,16 +176,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
212176
tmp_dmk,
213177
&one_int,
214178
&one_int,
215-
this->pv.desc,
179+
pv.desc,
216180
tmp3,
217181
&one_int,
218182
&one_int,
219-
this->pv.desc,
183+
pv.desc,
220184
&zero_float,
221185
tmp4,
222186
&one_int,
223187
&one_int,
224-
this->pv.desc);
188+
pv.desc);
225189

226190
pzgeadd_(&N_char,
227191
&nlocal,
@@ -230,12 +194,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
230194
tmp2,
231195
&one_int,
232196
&one_int,
233-
this->pv.desc,
197+
pv.desc,
234198
&half_float,
235199
tmp4,
236200
&one_int,
237201
&one_int,
238-
this->pv.desc);
202+
pv.desc);
239203

240204
zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc);
241205

@@ -247,12 +211,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
247211
delete[] tmp4;
248212
#else
249213
// for serial version
250-
tmp_edmk.create(this->pv.ncol, this->pv.nrow);
214+
tmp_edmk.create(pv.ncol, pv.nrow);
251215
ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
252216
ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);
253217

254-
hamilt::MatrixBlock<complex<double>> h_mat;
255-
hamilt::MatrixBlock<complex<double>> s_mat;
218+
hamilt::MatrixBlock<std::complex<double>> h_mat;
219+
hamilt::MatrixBlock<std::complex<double>> s_mat;
256220

257221
p_hamilt->matrix(h_mat, s_mat);
258222
// cout<<"hmat "<<h_mat.p[0]<<endl;
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#ifndef CAL_DM_PSI_H
2+
#define CAL_DM_PSI_H
3+
4+
#include "module_basis/module_ao/parallel_orbitals.h"
5+
#include "module_cell/klist.h"
6+
#include "module_elecstate/elecstate_lcao.h"
7+
#include "module_hamilt_general/hamilt.h"
8+
9+
namespace elecstate
10+
{
11+
void cal_edm_tddft(Parallel_Orbitals& pv,
12+
elecstate::ElecState* pelec,
13+
K_Vectors& kv,
14+
hamilt::Hamilt<std::complex<double>>* p_hamilt);
15+
} // namespace elecstate
16+
#endif

source/module_esolver/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ if(ENABLE_LCAO)
2626
lcao_others.cpp
2727
lcao_init_after_vc.cpp
2828
lcao_fun.cpp
29-
cal_edm_tddft.cpp
3029
)
3130
endif()
3231

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "module_base/lapack_connector.h"
1414
#include "module_base/scalapack_connector.h"
1515
#include "module_elecstate/module_charge/symmetry_rho.h"
16+
#include "module_elecstate/module_dm/cal_edm_tddft.h"
1617
#include "module_elecstate/occupy.h"
1718
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
1819
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
@@ -358,7 +359,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
358359
// calculate energy density matrix for tddft
359360
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
360361
{
361-
this->cal_edm_tddft();
362+
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
362363
}
363364
}
364365

source/module_esolver/esolver_ks_lcao_tddft.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, doubl
3737
virtual void iter_finish(const int istep, int& iter) override;
3838

3939
virtual void after_scf(const int istep) override;
40-
41-
void cal_edm_tddft();
4240
};
4341

4442
} // namespace ModuleESolver

0 commit comments

Comments
 (0)