Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ OBJS_ELECSTAT_LCAO=elecstate_lcao.o\
density_matrix.o\
density_matrix_io.o\
cal_dm_psi.o\
cal_edm_tddft.o\

OBJS_ESOLVER=esolver.o\
esolver_ks.o\
Expand All @@ -259,7 +260,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
lcao_others.o\
lcao_init_after_vc.o\
lcao_fun.o\
cal_edm_tddft.o\

OBJS_GINT=gint.o\
gint_gamma_env.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_elecstate/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ if(ENABLE_LCAO)
module_dm/density_matrix.cpp
module_dm/density_matrix_io.cpp
module_dm/cal_dm_psi.cpp
module_dm/cal_edm_tddft.cpp
)
endif()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,81 +1,46 @@
#include "esolver_ks_lcao_tddft.h"

#include "module_io/cal_r_overlap_R.h"
#include "module_io/dipole_io.h"
#include "module_io/td_current_io.h"
#include "module_io/write_HS.h"
#include "module_io/write_HS_R.h"
#include "module_io/write_wfc_nao.h"

//--------------temporary----------------------------
#include "module_base/blas_connector.h"
#include "module_base/global_function.h"
#include "module_base/scalapack_connector.h"
#include "cal_edm_tddft.h"

#include "module_base/lapack_connector.h"
#include "module_elecstate/module_charge/symmetry_rho.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
#include "module_hamilt_lcao/module_tddft/td_velocity.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/print_info.h"

//-----HSolver ElecState Hamilt--------
#include "module_elecstate/elecstate_lcao.h"
#include "module_elecstate/elecstate_lcao_tddft.h"
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
#include "module_hsolver/hsolver_lcao.h"
#include "module_parameter/parameter.h"
#include "module_psi/psi.h"

//-----force& stress-------------------
#include "module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h"

//---------------------------------------------------

namespace ModuleESolver
#include "module_base/scalapack_connector.h"
namespace elecstate
{

// use the original formula (Hamiltonian matrix) to calculate energy density
// matrix
void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
void cal_edm_tddft(Parallel_Orbitals& pv,
elecstate::ElecState* pelec,
K_Vectors& kv,
hamilt::Hamilt<std::complex<double>>* p_hamilt)
{
// mohan add 2024-03-27
const int nlocal = PARAM.globalv.nlocal;
assert(nlocal >= 0);

dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
->get_DM()
->EDMK.resize(kv.get_nks());
for (int ik = 0; ik < kv.get_nks(); ++ik) {
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(pelec);

p_hamilt->updateHk(ik);
_pelec->get_DM()->EDMK.resize(kv.get_nks());

std::complex<double>* tmp_dmk
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_DMK_pointer(ik);

ModuleBase::ComplexMatrix& tmp_edmk
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->EDMK[ik];

const Parallel_Orbitals* tmp_pv
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_paraV_pointer();
for (int ik = 0; ik < kv.get_nks(); ++ik)
{
p_hamilt->updateHk(ik);
std::complex<double>* tmp_dmk = _pelec->get_DM()->get_DMK_pointer(ik);
ModuleBase::ComplexMatrix& tmp_edmk = _pelec->get_DM()->EDMK[ik];

#ifdef __MPI

// mohan add 2024-03-27
//! be careful, the type of nloc is 'long'
//! whether the long type is safe, needs more discussion
const long nloc = this->pv.nloc;
const int ncol = this->pv.ncol;
const int nrow = this->pv.nrow;
const long nloc = pv.nloc;
const int ncol = pv.ncol;
const int nrow = pv.nrow;

tmp_edmk.create(ncol, nrow);
complex<double>* Htmp = new complex<double>[nloc];
complex<double>* Sinv = new complex<double>[nloc];
complex<double>* tmp1 = new complex<double>[nloc];
complex<double>* tmp2 = new complex<double>[nloc];
complex<double>* tmp3 = new complex<double>[nloc];
complex<double>* tmp4 = new complex<double>[nloc];
std::complex<double>* Htmp = new std::complex<double>[nloc];
std::complex<double>* Sinv = new std::complex<double>[nloc];
std::complex<double>* tmp1 = new std::complex<double>[nloc];
std::complex<double>* tmp2 = new std::complex<double>[nloc];
std::complex<double>* tmp3 = new std::complex<double>[nloc];
std::complex<double>* tmp4 = new std::complex<double>[nloc];

ModuleBase::GlobalFunc::ZEROS(Htmp, nloc);
ModuleBase::GlobalFunc::ZEROS(Sinv, nloc);
Expand All @@ -86,8 +51,8 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()

const int inc = 1;

hamilt::MatrixBlock<complex<double>> h_mat;
hamilt::MatrixBlock<complex<double>> s_mat;
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;

p_hamilt->matrix(h_mat, s_mat);
zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc);
Expand All @@ -97,7 +62,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
int info = 0;
const int one_int = 1;

pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, this->pv.desc, ipiv.data(), &info);
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info);

int lwork = -1;
int liwork = -1;
Expand All @@ -112,7 +77,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
ipiv.data(),
work.data(),
&lwork,
Expand All @@ -129,7 +94,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
ipiv.data(),
work.data(),
&lwork,
Expand All @@ -139,9 +104,9 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()

const char N_char = 'N';
const char T_char = 'T';
const complex<double> one_float = {1.0, 0.0};
const complex<double> zero_float = {0.0, 0.0};
const complex<double> half_float = {0.5, 0.0};
const std::complex<double> one_float = {1.0, 0.0};
const std::complex<double> zero_float = {0.0, 0.0};
const std::complex<double> half_float = {0.5, 0.0};

pzgemm_(&N_char,
&N_char,
Expand All @@ -152,16 +117,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Htmp,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp1,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&T_char,
&N_char,
Expand All @@ -172,16 +137,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp1,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
tmp_dmk,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp2,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&N_char,
&N_char,
Expand All @@ -192,16 +157,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
Htmp,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp3,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&N_char,
&T_char,
Expand All @@ -212,16 +177,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp_dmk,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
tmp3,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp4,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgeadd_(&N_char,
&nlocal,
Expand All @@ -230,12 +195,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp2,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&half_float,
tmp4,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc);

Expand All @@ -247,12 +212,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
delete[] tmp4;
#else
// for serial version
tmp_edmk.create(this->pv.ncol, this->pv.nrow);
tmp_edmk.create(pv.ncol, pv.nrow);
ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);

hamilt::MatrixBlock<complex<double>> h_mat;
hamilt::MatrixBlock<complex<double>> s_mat;
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;

p_hamilt->matrix(h_mat, s_mat);
// cout<<"hmat "<<h_mat.p[0]<<endl;
Expand Down
16 changes: 16 additions & 0 deletions source/module_elecstate/module_dm/cal_edm_tddft.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef CAL_EDM_TDDFT_H
#define CAL_EDM_TDDFT_H

#include "module_basis/module_ao/parallel_orbitals.h"
#include "module_cell/klist.h"
#include "module_elecstate/elecstate_lcao.h"
#include "module_hamilt_general/hamilt.h"

namespace elecstate
{
void cal_edm_tddft(Parallel_Orbitals& pv,
elecstate::ElecState* pelec,
K_Vectors& kv,
hamilt::Hamilt<std::complex<double>>* p_hamilt);
} // namespace elecstate
#endif // CAL_EDM_TDDFT_H
1 change: 0 additions & 1 deletion source/module_esolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ if(ENABLE_LCAO)
lcao_others.cpp
lcao_init_after_vc.cpp
lcao_fun.cpp
cal_edm_tddft.cpp
)
endif()

Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "module_base/lapack_connector.h"
#include "module_base/scalapack_connector.h"
#include "module_elecstate/module_charge/symmetry_rho.h"
#include "module_elecstate/module_dm/cal_edm_tddft.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
Expand Down Expand Up @@ -358,7 +359,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
// calculate energy density matrix for tddft
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
{
this->cal_edm_tddft();
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
}
}

Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao_tddft.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, doubl
virtual void iter_finish(const int istep, int& iter) override;

virtual void after_scf(const int istep) override;

void cal_edm_tddft();
};

} // namespace ModuleESolver
Expand Down
Loading