Skip to content

Commit 9bf2533

Browse files
authored
Refactor: Remove global dependence of some functions in DeePKS. (#5778)
* Move cal_o_delta from GlobalC::ld to DeePKS_domain and remove variable o_delta in ld. * Remove F_delta in ld and lessen the tedious dimension in orbital related variables in DeePKS.
1 parent 5e8bc21 commit 9bf2533

18 files changed

+350
-357
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ OBJS_CELL=atom_pseudo.o\
191191

192192
OBJS_DEEPKS=LCAO_deepks.o\
193193
deepks_force.o\
194-
LCAO_deepks_odelta.o\
194+
deepks_orbital.o\
195195
LCAO_deepks_io.o\
196196
LCAO_deepks_mpi.o\
197197
LCAO_deepks_pdm.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class Force_LCAO
6363
ModuleBase::matrix& svnl_dbeta,
6464
ModuleBase::matrix& svl_dphi,
6565
#ifdef __DEEPKS
66+
ModuleBase::matrix& fvnl_dalpha,
6667
ModuleBase::matrix& svnl_dalpha,
6768
#endif
6869
typename TGint<T>::type& gint,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
8080
ModuleBase::matrix fewalds;
8181
ModuleBase::matrix fcc;
8282
ModuleBase::matrix fscc;
83+
#ifdef __DEEPKS
84+
ModuleBase::matrix fvnl_dalpha; // deepks
85+
#endif
8386

8487
fvl_dphi.create(nat, 3); // must do it now, update it later, noted by zhengdy
8588

@@ -93,6 +96,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
9396
fewalds.create(nat, 3);
9497
fcc.create(nat, 3);
9598
fscc.create(nat, 3);
99+
#ifdef __DEEPKS
100+
fvnl_dalpha.create(nat, 3); // deepks
101+
#endif
96102

97103
// calculate basic terms in Force, same method with PW base
98104
this->calForcePwPart(ucell,
@@ -172,6 +178,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
172178
svnl_dbeta,
173179
svl_dphi,
174180
#ifdef __DEEPKS
181+
fvnl_dalpha,
175182
svnl_dalpha,
176183
#endif
177184
gint_gamma,
@@ -454,7 +461,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
454461
// mohan add 2021-08-04
455462
if (PARAM.inp.deepks_scf)
456463
{
457-
fcs(iat, i) += GlobalC::ld.F_delta(iat, i);
464+
fcs(iat, i) += fvnl_dalpha(iat, i);
458465
}
459466
#endif
460467
// sum total force for correction
@@ -499,7 +506,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
499506
if (PARAM.inp.deepks_scf)
500507
{
501508
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
502-
LCAO_deepks_io::save_npy_f(fcs - GlobalC::ld.F_delta,
509+
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha,
503510
file_fbase,
504511
ucell.nat,
505512
GlobalV::MY_RANK); // Ry/Bohr, F_base
@@ -636,8 +643,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
636643
// caoyu add 2021-06-03
637644
if (PARAM.inp.deepks_scf)
638645
{
639-
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", GlobalC::ld.F_delta, true);
640-
// this->print_force("DeePKS FORCE", GlobalC::ld.F_delta, 1, ry);
646+
ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS FORCE", fvnl_dalpha, true);
641647
}
642648
#endif
643649
}
@@ -891,6 +897,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
891897
ModuleBase::matrix& svnl_dbeta,
892898
ModuleBase::matrix& svl_dphi,
893899
#if __DEEPKS
900+
ModuleBase::matrix& fvnl_dalpha,
894901
ModuleBase::matrix& svnl_dalpha,
895902
#endif
896903
Gint_Gamma& gint_gamma, // mohan add 2024-04-01
@@ -917,6 +924,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
917924
svnl_dbeta,
918925
svl_dphi,
919926
#if __DEEPKS
927+
fvnl_dalpha,
920928
svnl_dalpha,
921929
#endif
922930
gint_gamma,
@@ -944,6 +952,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
944952
ModuleBase::matrix& svnl_dbeta,
945953
ModuleBase::matrix& svl_dphi,
946954
#if __DEEPKS
955+
ModuleBase::matrix& fvnl_dalpha,
947956
ModuleBase::matrix& svnl_dalpha,
948957
#endif
949958
Gint_Gamma& gint_gamma,
@@ -969,6 +978,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
969978
svnl_dbeta,
970979
svl_dphi,
971980
#if __DEEPKS
981+
fvnl_dalpha,
972982
svnl_dalpha,
973983
#endif
974984
gint_k,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class Force_Stress_LCAO
9696
ModuleBase::matrix& svnl_dbeta,
9797
ModuleBase::matrix& svl_dphi,
9898
#if __DEEPKS
99+
ModuleBase::matrix& fvnl_dalpha,
99100
ModuleBase::matrix& svnl_dalpha,
100101
#endif
101102
Gint_Gamma& gint_gamma,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
188188
ModuleBase::matrix& svnl_dbeta,
189189
ModuleBase::matrix& svl_dphi,
190190
#ifdef __DEEPKS
191+
ModuleBase::matrix& fvnl_dalpha,
191192
ModuleBase::matrix& svnl_dalpha,
192193
#endif
193194
TGint<double>::type& gint,
@@ -246,15 +247,13 @@ void Force_LCAO<double>::ftable(const bool isforce,
246247
false /*reset dm to gint*/);
247248

248249
#ifdef __DEEPKS
250+
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
249251
if (PARAM.inp.deepks_scf)
250252
{
251-
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
252-
253253
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
254254
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
255255

256256
GlobalC::ld.cal_descriptor(ucell.nat);
257-
258257
GlobalC::ld.cal_gedm(ucell.nat);
259258

260259
const int nks = 1;
@@ -269,40 +268,9 @@ void Force_LCAO<double>::ftable(const bool isforce,
269268
GlobalC::ld.phialpha,
270269
GlobalC::ld.gedm,
271270
GlobalC::ld.inl_index,
272-
GlobalC::ld.F_delta,
271+
fvnl_dalpha,
273272
isstress,
274273
svnl_dalpha);
275-
276-
#ifdef __MPI
277-
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);
278-
279-
if (isstress)
280-
{
281-
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
282-
}
283-
#endif
284-
285-
if (PARAM.inp.deepks_out_unittest)
286-
{
287-
const int nks = 1; // 1 for gamma-only
288-
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);
289-
290-
GlobalC::ld.check_projected_dm();
291-
292-
GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
293-
294-
GlobalC::ld.check_gedm();
295-
296-
GlobalC::ld.cal_e_delta_band(dm_gamma, nks);
297-
298-
std::ofstream ofs("E_delta_bands.dat");
299-
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;
300-
301-
std::ofstream ofs1("E_delta.dat");
302-
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;
303-
304-
DeePKS_domain::check_f_delta(ucell.nat, GlobalC::ld.F_delta, svnl_dalpha);
305-
}
306274
}
307275
#endif
308276

@@ -312,14 +280,46 @@ void Force_LCAO<double>::ftable(const bool isforce,
312280
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
313281
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
314282
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
283+
#ifdef __DEEPKS
284+
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
285+
#endif
315286
}
316287
if (isstress)
317288
{
318289
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
319290
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
320291
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
321292
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
293+
#ifdef __DEEPKS
294+
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
295+
#endif
296+
}
297+
298+
#ifdef __DEEPKS
299+
// It seems these test should not all be here, should be moved in the future
300+
// Also, these test are not in multi-k case now
301+
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
302+
{
303+
const int nks = 1; // 1 for gamma-only
304+
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);
305+
306+
GlobalC::ld.check_projected_dm();
307+
308+
GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
309+
310+
GlobalC::ld.check_gedm();
311+
312+
GlobalC::ld.cal_e_delta_band(dm_gamma, nks);
313+
314+
std::ofstream ofs("E_delta_bands.dat");
315+
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;
316+
317+
std::ofstream ofs1("E_delta.dat");
318+
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;
319+
320+
DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
322321
}
322+
#endif
323323

324324
// delete DSloc_x, DSloc_y, DSloc_z
325325
// delete DHloc_fixed_x, DHloc_fixed_y, DHloc_fixed_z

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
282282
ModuleBase::matrix& svnl_dbeta,
283283
ModuleBase::matrix& svl_dphi,
284284
#ifdef __DEEPKS
285+
ModuleBase::matrix& fvnl_dalpha,
285286
ModuleBase::matrix& svnl_dalpha,
286287
#endif
287288
TGint<std::complex<double>>::type& gint,
@@ -363,17 +364,9 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
363364
GlobalC::ld.phialpha,
364365
GlobalC::ld.gedm,
365366
GlobalC::ld.inl_index,
366-
GlobalC::ld.F_delta,
367+
fvnl_dalpha,
367368
isstress,
368369
svnl_dalpha);
369-
370-
#ifdef __MPI
371-
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);
372-
if (isstress)
373-
{
374-
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
375-
}
376-
#endif
377370
}
378371
#endif
379372

@@ -386,13 +379,19 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
386379
Parallel_Reduce::reduce_pool(ftvnl_dphi.c, ftvnl_dphi.nr * ftvnl_dphi.nc);
387380
Parallel_Reduce::reduce_pool(fvnl_dbeta.c, fvnl_dbeta.nr * fvnl_dbeta.nc);
388381
Parallel_Reduce::reduce_pool(fvl_dphi.c, fvl_dphi.nr * fvl_dphi.nc);
382+
#ifdef __DEEPKS
383+
Parallel_Reduce::reduce_pool(fvnl_dalpha.c, fvnl_dalpha.nr * fvnl_dalpha.nc);
384+
#endif
389385
}
390386
if (isstress)
391387
{
392388
Parallel_Reduce::reduce_pool(soverlap.c, soverlap.nr * soverlap.nc);
393389
Parallel_Reduce::reduce_pool(stvnl_dphi.c, stvnl_dphi.nr * stvnl_dphi.nc);
394390
Parallel_Reduce::reduce_pool(svnl_dbeta.c, svnl_dbeta.nr * svnl_dbeta.nc);
395391
Parallel_Reduce::reduce_pool(svl_dphi.c, svl_dphi.nr * svl_dphi.nc);
392+
#ifdef __DEEPKS
393+
Parallel_Reduce::reduce_pool(svnl_dalpha.c, svnl_dalpha.nr * svnl_dalpha.nc);
394+
#endif
396395
}
397396

398397
ModuleBase::timer::tick("Force_LCAO", "ftable");

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ if(ENABLE_DEEPKS)
22
list(APPEND objects
33
LCAO_deepks.cpp
44
deepks_force.cpp
5-
LCAO_deepks_odelta.cpp
5+
deepks_orbital.cpp
66
LCAO_deepks_io.cpp
77
LCAO_deepks_mpi.cpp
88
LCAO_deepks_pdm.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,6 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
332332
}
333333
if (PARAM.inp.cal_force)
334334
{
335-
// init F_delta
336-
F_delta.create(nat, 3);
337335
if (PARAM.inp.deepks_out_labels)
338336
{
339337
this->init_gdmx(nat);
@@ -342,34 +340,24 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
342340
// gdmx is used only in calculating gvx
343341
}
344342

345-
if (PARAM.inp.deepks_bandgap)
346-
{
347-
// init o_delta
348-
o_delta.create(nks, 1);
349-
}
350-
351343
return;
352344
}
353345

354346
void LCAO_Deepks::init_orbital_pdm_shell(const int nks)
355347
{
356348

357-
this->orbital_pdm_shell = new double***[nks];
349+
this->orbital_pdm_shell = new double**[nks];
358350

359351
for (int iks = 0; iks < nks; iks++)
360352
{
361-
this->orbital_pdm_shell[iks] = new double**[1];
362-
for (int hl = 0; hl < 1; hl++)
353+
this->orbital_pdm_shell[iks] = new double*[this->inlmax];
354+
for (int inl = 0; inl < this->inlmax; inl++)
363355
{
364-
this->orbital_pdm_shell[iks][hl] = new double*[this->inlmax];
365-
366-
for (int inl = 0; inl < this->inlmax; inl++)
367-
{
368-
this->orbital_pdm_shell[iks][hl][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
369-
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][hl][inl],
370-
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
371-
}
356+
this->orbital_pdm_shell[iks][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
357+
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][inl],
358+
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
372359
}
360+
373361
}
374362

375363
return;
@@ -379,13 +367,9 @@ void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
379367
{
380368
for (int iks = 0; iks < nks; iks++)
381369
{
382-
for (int hl = 0; hl < 1; hl++)
370+
for (int inl = 0; inl < this->inlmax; inl++)
383371
{
384-
for (int inl = 0; inl < this->inlmax; inl++)
385-
{
386-
delete[] this->orbital_pdm_shell[iks][hl][inl];
387-
}
388-
delete[] this->orbital_pdm_shell[iks][hl];
372+
delete[] this->orbital_pdm_shell[iks][inl];
389373
}
390374
delete[] this->orbital_pdm_shell[iks];
391375
}

0 commit comments

Comments
 (0)