Skip to content

Commit 3bfc989

Browse files
committed
Change cal_gdmx into template function.
1 parent 44e83c7 commit 3bfc989

File tree

9 files changed

+96
-283
lines changed

9 files changed

+96
-283
lines changed

source/Makefile.Objects

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ OBJS_DEEPKS=LCAO_deepks.o\
202202
LCAO_deepks_interface.o\
203203
orbital_precalc.o\
204204
cal_gdmx.o\
205-
cal_gdmx_k.o\
206205
cal_gedm.o\
207206
cal_gvx.o\
208207
cal_descriptor.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ void Force_Stress_LCAO<T>::getForceStress(const bool isforce,
522522
{
523523
const std::vector<std::vector<double>>& dm_gamma
524524
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
525-
GlobalC::ld.cal_gdmx(dm_gamma[0], ucell, orb, GlobalC::GridD, isstress);
525+
GlobalC::ld.cal_gdmx(dm_gamma, ucell, orb, GlobalC::GridD, kv.get_nks(), kv.kvec_d, isstress);
526526
}
527527
else
528528
{
@@ -531,7 +531,7 @@ void Force_Stress_LCAO<T>::getForceStress(const bool isforce,
531531
->get_DM()
532532
->get_DMK_vector();
533533

534-
GlobalC::ld.cal_gdmx_k(dm_k,
534+
GlobalC::ld.cal_gdmx(dm_k,
535535
ucell,
536536
orb,
537537
GlobalC::GridD,

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ if(ENABLE_DEEPKS)
1414
LCAO_deepks_interface.cpp
1515
orbital_precalc.cpp
1616
cal_gdmx.cpp
17-
cal_gdmx_k.cpp
1817
cal_gedm.cpp
1918
cal_gvx.cpp
2019
cal_descriptor.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ class LCAO_Deepks
287287
// 3. check_projected_dm, which prints pdm to descriptor.dat
288288

289289
// 4. cal_gdmx, calculating gdmx (and optionally gdm_epsl for stress) for gamma point
290-
// 5. cal_gdmx_k, counterpart of 3, for multi-k
291-
// 6. check_gdmx, which prints gdmx to a series of .dat files
290+
// 5. check_gdmx, which prints gdmx to a series of .dat files
292291

293292
public:
294293
/**
@@ -323,21 +322,16 @@ class LCAO_Deepks
323322

324323
// calculate the gradient of pdm with regard to atomic positions
325324
// d/dX D_{Inl,mm'}
325+
template <typename TK>
326326
void cal_gdmx( // const ModuleBase::matrix& dm,
327-
const std::vector<double>& dm,
328-
const UnitCell& ucell,
329-
const LCAO_Orbitals& orb,
330-
Grid_Driver& GridD,
331-
const bool isstress);
332-
333-
void cal_gdmx_k( // const std::vector<ModuleBase::ComplexMatrix>& dm,
334-
const std::vector<std::vector<std::complex<double>>>& dm,
327+
const std::vector<std::vector<TK>>& dm,
335328
const UnitCell& ucell,
336329
const LCAO_Orbitals& orb,
337330
Grid_Driver& GridD,
338331
const int nks,
339332
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
340333
const bool isstress);
334+
341335
void check_gdmx(const int nat);
342336

343337
/**

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,6 @@ void LCAO_Deepks_Interface<TK,TR>::out_deepks_labels(const double& etot,
9797
elecstate::cal_dm(ParaV, wg_hl, psi, dm_bandgap[ib]);
9898
}
9999
}
100-
101-
ld->cal_orbital_precalc<TK,TH>(dm_bandgap, nat, nks, kvec_d, ucell, orb, GridD);
102100
}
103101
else // for multi-k
104102
{
@@ -116,9 +114,9 @@ void LCAO_Deepks_Interface<TK,TR>::out_deepks_labels(const double& etot,
116114
dm_bandgap[ib].resize(nks);
117115
elecstate::cal_dm(ParaV, wg_hl, psi, dm_bandgap[ib]);
118116
}
119-
// ld->cal_o_delta(dm_bandgap, ParaV, nks);
120-
ld->cal_orbital_precalc<TK,TH>(dm_bandgap, nat, nks, kvec_d, ucell, orb, GridD);
121117
}
118+
119+
ld->cal_orbital_precalc<TK,TH>(dm_bandgap, nat, nks, kvec_d, ucell, orb, GridD);
122120
ld->cal_o_delta(dm_bandgap, nks);
123121

124122
// save obase and orbital_precalc

source/module_hamilt_lcao/module_deepks/LCAO_deepks_pdm.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
//3. check_projected_dm, which prints pdm to descriptor.dat
1515

1616
//4. cal_gdmx, calculating gdmx (and optionally gdm_epsl for stress) for gamma point
17-
//5. cal_gdmx_k, counterpart of 3, for multi-k
18-
//6. check_gdmx, which prints gdmx to a series of .dat files
17+
//5. check_gdmx, which prints gdmx to a series of .dat files
1918

2019
#ifdef __DEEPKS
2120

source/module_hamilt_lcao/module_deepks/cal_gdmx.cpp

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
/// be calculated:
1515
/// gdm_epsl = d/d\epsilon_{ab} *
1616
/// sum_{mu,nu} rho_{mu,nu} <chi_mu|alpha_m><alpha_m'|chi_nu>
17-
void LCAO_Deepks::cal_gdmx(const std::vector<double>& dm,
17+
template <typename TK>
18+
void LCAO_Deepks::cal_gdmx(const std::vector<std::vector<TK>>& dm,
1819
const UnitCell &ucell,
1920
const LCAO_Orbitals &orb,
2021
Grid_Driver& GridD,
22+
const int nks,
23+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
2124
const bool isstress)
2225
{
2326
ModuleBase::TITLE("LCAO_Deepks", "cal_gdmx");
@@ -70,6 +73,8 @@ void LCAO_Deepks::cal_gdmx(const std::vector<double>& dm,
7073
const int nw1_tot = atom1->nw*PARAM.globalv.npol;
7174
const double Rcut_AO1 = orb.Phi[T1].getRcut();
7275

76+
ModuleBase::Vector3<double> dR1(GridD.getBox(ad1).x, GridD.getBox(ad1).y, GridD.getBox(ad1).z);
77+
7378
for (int ad2=0; ad2 < GridD.getAdjacentNum()+1 ; ad2++)
7479
{
7580
const int T2 = GridD.getType(ad2);
@@ -79,6 +84,7 @@ void LCAO_Deepks::cal_gdmx(const std::vector<double>& dm,
7984
const ModuleBase::Vector3<double> tau2 = GridD.getAdjacentTau(ad2);
8085
const Atom* atom2 = &ucell.atoms[T2];
8186
const int nw2_tot = atom2->nw*PARAM.globalv.npol;
87+
ModuleBase::Vector3<double> dR2(GridD.getBox(ad2).x, GridD.getBox(ad2).y, GridD.getBox(ad2).z);
8288

8389
const double Rcut_AO2 = orb.Phi[T2].getRcut();
8490
const double dist1 = (tau1-tau0).norm() * ucell.lat0;
@@ -104,25 +110,68 @@ void LCAO_Deepks::cal_gdmx(const std::vector<double>& dm,
104110
auto row_indexes = pv->get_indexes_row(ibt1);
105111
auto col_indexes = pv->get_indexes_col(ibt2);
106112
if(row_indexes.size() * col_indexes.size() == 0) continue;
107-
108-
hamilt::AtomPair<double> dm_pair(ibt1, ibt2, 0, 0, 0, pv);
109-
dm_pair.allocate(nullptr, 1);
110-
if(ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver))
113+
114+
double* dm_current;
115+
int dRx, dRy, dRz;
116+
if constexpr (std::is_same<TK, double>::value)
111117
{
112-
dm_pair.add_from_matrix(dm.data(), pv->get_row_size(), 1.0, 1);
118+
dRx = 0;
119+
dRy = 0;
120+
dRz = 0;
113121
}
114122
else
115123
{
116-
dm_pair.add_from_matrix(dm.data(), pv->get_col_size(), 1.0, 0);
124+
dRx = (dR2-dR1).x;
125+
dRy = (dR2-dR1).y;
126+
dRz = (dR2-dR1).z;
117127
}
118-
const double* dm_current = dm_pair.get_pointer();
128+
hamilt::AtomPair<double> dm_pair(ibt1, ibt2, dRx, dRy, dRz, pv);
129+
dm_pair.allocate(nullptr, 1);
130+
for(int ik=0;ik<nks;ik++)
131+
{
132+
TK kphase;
133+
if constexpr (std::is_same<TK, double>::value)
134+
{
135+
kphase = 1.0;
136+
}
137+
else
138+
{
139+
const double arg = - (kvec_d[ik] * (dR2-dR1) ) * ModuleBase::TWO_PI;
140+
double sinp, cosp;
141+
ModuleBase::libm::sincos(arg, &sinp, &cosp);
142+
kphase = TK(cosp, sinp);
143+
}
144+
if(ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver))
145+
{
146+
dm_pair.add_from_matrix(dm[ik].data(), pv->get_row_size(), kphase, 1);
147+
}
148+
else
149+
{
150+
dm_pair.add_from_matrix(dm[ik].data(), pv->get_col_size(), kphase, 0);
151+
}
152+
}
153+
154+
dm_current = dm_pair.get_pointer();
119155

156+
key_tuple key_1(ibt1,dR1.x,dR1.y,dR1.z);
157+
key_tuple key_2(ibt2,dR2.x,dR2.y,dR2.z);
120158
for (int iw1=0; iw1<row_indexes.size(); ++iw1)
121159
{
122160
for (int iw2=0; iw2<col_indexes.size(); ++iw2)
123161
{
124-
std::vector<double> nlm1 = this->nlm_save[iat][ad1][row_indexes[iw1]][0];
125-
std::vector<std::vector<double>> nlm2 = this->nlm_save[iat][ad2][col_indexes[iw2]];
162+
std::vector<double> nlm1;
163+
std::vector<std::vector<double>> nlm2;
164+
165+
if constexpr (std::is_same<TK, double>::value)
166+
{
167+
nlm1 = this->nlm_save[iat][ad1][row_indexes[iw1]][0];
168+
nlm2 = this->nlm_save[iat][ad2][col_indexes[iw2]];
169+
}
170+
else
171+
{
172+
nlm1 = this->nlm_save_k[iat][key_1][row_indexes[iw1]][0];
173+
nlm2 = this->nlm_save_k[iat][key_2][col_indexes[iw2]];
174+
}
126175

127176
assert(nlm1.size()==nlm2[0].size());
128177

@@ -178,8 +227,16 @@ void LCAO_Deepks::cal_gdmx(const std::vector<double>& dm,
178227
assert(ib==nlm1.size());
179228
if (isstress)
180229
{
181-
nlm1 = this->nlm_save[iat][ad2][col_indexes[iw2]][0];
182-
nlm2 = this->nlm_save[iat][ad1][row_indexes[iw1]];
230+
if constexpr (std::is_same<TK, double>::value)
231+
{
232+
nlm1 = this->nlm_save[iat][ad2][col_indexes[iw2]][0];
233+
nlm2 = this->nlm_save[iat][ad1][row_indexes[iw1]];
234+
}
235+
else
236+
{
237+
nlm1 = this->nlm_save_k[iat][key_2][col_indexes[iw2]][0];
238+
nlm2 = this->nlm_save_k[iat][key_1][row_indexes[iw1]];
239+
}
183240

184241
assert(nlm1.size()==nlm2[0].size());
185242
int ib=0;
@@ -278,4 +335,20 @@ void LCAO_Deepks::check_gdmx(const int nat)
278335
}
279336
}
280337

338+
template void LCAO_Deepks::cal_gdmx<double>(const std::vector<std::vector<double>>& dm,
339+
const UnitCell &ucell,
340+
const LCAO_Orbitals &orb,
341+
Grid_Driver& GridD,
342+
const int nks,
343+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
344+
const bool isstress);
345+
346+
template void LCAO_Deepks::cal_gdmx<std::complex<double>>(const std::vector<std::vector<std::complex<double>>>& dm,
347+
const UnitCell &ucell,
348+
const LCAO_Orbitals &orb,
349+
Grid_Driver& GridD,
350+
const int nks,
351+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
352+
const bool isstress);
353+
281354
#endif

0 commit comments

Comments
 (0)