Skip to content

Commit 974a84b

Browse files
ErjieWuFisherd99
authored andcommitted
Add support for INPUT deepks_v_delta>0 in multi-k points DeePKS calculations (deepmodeling#5700)
1 parent 0f8f848 commit 974a84b

File tree

14 files changed

+810
-67
lines changed

14 files changed

+810
-67
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
208208
cal_gvx.o\
209209
cal_descriptor.o\
210210
v_delta_precalc.o\
211+
v_delta_precalc_k.o\
211212

212213

213214
OBJS_ELECSTAT=elecstate.o\

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10021002
#ifdef __DEEPKS
10031003
if (PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta)
10041004
{
1005-
DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc);
1005+
DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc, ik);
10061006
}
10071007
#endif
10081008
}

source/module_hamilt_lcao/hamilt_lcaodft/LCAO_allocate.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void divide_HS_in_frag(const bool isGamma, const UnitCell& ucell, Parallel_Orbit
2727
GlobalC::ld.init(orb,
2828
ucell.nat,
2929
ucell.ntype,
30+
nks,
3031
pv,
3132
na);
3233

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ if(ENABLE_DEEPKS)
2020
cal_gvx.cpp
2121
cal_descriptor.cpp
2222
v_delta_precalc.cpp
23+
v_delta_precalc_k.cpp
2324
)
2425

2526
add_library(

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ void LCAO_Deepks::init(
7474
const LCAO_Orbitals& orb,
7575
const int nat,
7676
const int ntype,
77+
const int nks,
7778
const Parallel_Orbitals& pv_in,
7879
std::vector<int> na)
7980
{
@@ -97,13 +98,13 @@ void LCAO_Deepks::init(
9798
this->nmaxd = nm;
9899

99100
GlobalV::ofs_running << " lmax of descriptor = " << this->lmaxd << std::endl;
100-
GlobalV::ofs_running << " nmax of descriptor= " << nmaxd << std::endl;
101+
GlobalV::ofs_running << " nmax of descriptor = " << nmaxd << std::endl;
101102

102103
int pdm_size = 0;
103104
this->inlmax = tot_inl;
104105
if(!PARAM.inp.deepks_equiv)
105106
{
106-
GlobalV::ofs_running << " total basis (all atoms) for descriptor= " << std::endl;
107+
GlobalV::ofs_running << " total basis (all atoms) for descriptor = " << std::endl;
107108

108109
//init pdm**
109110
pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
@@ -150,6 +151,15 @@ void LCAO_Deepks::init(
150151
int nloc=this->pv->nloc;
151152
this->h_mat.resize(nloc,0.0);
152153
}
154+
else
155+
{
156+
int nloc=this->pv->nloc;
157+
this->h_mat_k.resize(nks);
158+
for (int ik = 0; ik < nks; ik++)
159+
{
160+
this->h_mat_k[ik].resize(nloc,std::complex<double>(0.0,0.0));
161+
}
162+
}
153163
}
154164

155165
return;
@@ -431,27 +441,51 @@ void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
431441

432442
void LCAO_Deepks::init_v_delta_pdm_shell(const int nks,const int nlocal)
433443
{
434-
435-
this->v_delta_pdm_shell = new double**** [nks];
436-
437444
const int mn_size=(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1);
438-
for (int iks=0; iks<nks; iks++)
439-
{
440-
this->v_delta_pdm_shell[iks] = new double*** [nlocal];
445+
if (nks==1){
446+
this->v_delta_pdm_shell = new double**** [nks];
447+
for (int iks=0; iks<nks; iks++)
448+
{
449+
this->v_delta_pdm_shell[iks] = new double*** [nlocal];
441450

442-
for (int mu=0; mu<nlocal; mu++)
451+
for (int mu=0; mu<nlocal; mu++)
452+
{
453+
this->v_delta_pdm_shell[iks][mu] = new double** [nlocal];
454+
455+
for (int nu=0; nu<nlocal; nu++)
456+
{
457+
this->v_delta_pdm_shell[iks][mu][nu] = new double* [this->inlmax];
458+
459+
for(int inl = 0; inl < this->inlmax; inl++)
460+
{
461+
this->v_delta_pdm_shell[iks][mu][nu][inl] = new double [mn_size];
462+
ModuleBase::GlobalFunc::ZEROS(v_delta_pdm_shell[iks][mu][nu][inl], mn_size);
463+
}
464+
}
465+
}
466+
}
467+
}
468+
else
469+
{
470+
this->v_delta_pdm_shell_complex = new std::complex<double>**** [nks];
471+
for (int iks=0; iks<nks; iks++)
443472
{
444-
this->v_delta_pdm_shell[iks][mu] = new double** [nlocal];
473+
this->v_delta_pdm_shell_complex[iks] = new std::complex<double>*** [nlocal];
445474

446-
for (int nu=0; nu<nlocal; nu++)
475+
for (int mu=0; mu<nlocal; mu++)
447476
{
448-
this->v_delta_pdm_shell[iks][mu][nu] = new double* [this->inlmax];
477+
this->v_delta_pdm_shell_complex[iks][mu] = new std::complex<double>** [nlocal];
449478

450-
for(int inl = 0; inl < this->inlmax; inl++)
479+
for (int nu=0; nu<nlocal; nu++)
451480
{
452-
this->v_delta_pdm_shell[iks][mu][nu][inl] = new double [mn_size];
453-
ModuleBase::GlobalFunc::ZEROS(v_delta_pdm_shell[iks][mu][nu][inl], mn_size);
454-
}
481+
this->v_delta_pdm_shell_complex[iks][mu][nu] = new std::complex<double>* [this->inlmax];
482+
483+
for(int inl = 0; inl < this->inlmax; inl++)
484+
{
485+
this->v_delta_pdm_shell_complex[iks][mu][nu][inl] = new std::complex<double> [mn_size];
486+
ModuleBase::GlobalFunc::ZEROS(v_delta_pdm_shell_complex[iks][mu][nu][inl], mn_size);
487+
}
488+
}
455489
}
456490
}
457491
}
@@ -461,23 +495,46 @@ void LCAO_Deepks::init_v_delta_pdm_shell(const int nks,const int nlocal)
461495

462496
void LCAO_Deepks::del_v_delta_pdm_shell(const int nks,const int nlocal)
463497
{
464-
for (int iks=0; iks<nks; iks++)
498+
if (nks==1)
465499
{
466-
for (int mu=0; mu<nlocal; mu++)
500+
for (int iks=0; iks<nks; iks++)
467501
{
468-
for (int nu=0; nu<nlocal; nu++)
502+
for (int mu=0; mu<nlocal; mu++)
469503
{
470-
for (int inl = 0;inl < this->inlmax; inl++)
504+
for (int nu=0; nu<nlocal; nu++)
471505
{
472-
delete[] this->v_delta_pdm_shell[iks][mu][nu][inl];
506+
for (int inl = 0;inl < this->inlmax; inl++)
507+
{
508+
delete[] this->v_delta_pdm_shell[iks][mu][nu][inl];
509+
}
510+
delete[] this->v_delta_pdm_shell[iks][mu][nu];
511+
}
512+
delete[] this->v_delta_pdm_shell[iks][mu];
513+
}
514+
delete[] this->v_delta_pdm_shell[iks];
515+
}
516+
delete[] this->v_delta_pdm_shell;
517+
}
518+
else
519+
{
520+
for (int iks=0; iks<nks; iks++)
521+
{
522+
for (int mu=0; mu<nlocal; mu++)
523+
{
524+
for (int nu=0; nu<nlocal; nu++)
525+
{
526+
for (int inl = 0;inl < this->inlmax; inl++)
527+
{
528+
delete[] this->v_delta_pdm_shell_complex[iks][mu][nu][inl];
529+
}
530+
delete[] this->v_delta_pdm_shell_complex[iks][mu][nu];
473531
}
474-
delete[] this->v_delta_pdm_shell[iks][mu][nu];
532+
delete[] this->v_delta_pdm_shell_complex[iks][mu];
475533
}
476-
delete[] this->v_delta_pdm_shell[iks][mu];
534+
delete[] this->v_delta_pdm_shell_complex[iks];
477535
}
478-
delete[] this->v_delta_pdm_shell[iks];
536+
delete[] this->v_delta_pdm_shell_complex;
479537
}
480-
delete[] this->v_delta_pdm_shell;
481538

482539
return;
483540
}

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ class LCAO_Deepks
5454
///\rho_{HL} = c_{L, \mu}c_{L,\nu} - c_{H, \mu}c_{H,\nu} \f$ (for gamma_only)
5555
ModuleBase::matrix o_delta;
5656

57-
///(Unit: Ry) Hamiltonian matrix
57+
///(Unit: Ry) Hamiltonian matrix in k space
58+
/// for gamma only
5859
std::vector<double> h_mat;
60+
/// for multi-k
61+
std::vector<std::vector<std::complex<double>>> h_mat_k;
5962

6063
/// Correction term to the Hamiltonian matrix: \f$\langle\psi|V_\delta|\psi\rangle\f$ (for gamma only)
6164
std::vector<double> H_V_delta;
@@ -159,13 +162,14 @@ class LCAO_Deepks
159162
// dD/dX, tensor form of gdmx
160163
std::vector<torch::Tensor> gdmr_vector;
161164

162-
// orbital_pdm_shell:[1,Inl,nm*nm]; \langle \phi_\mu|\alpha\rangle\langle\alpha|\phi_\nu\rnalge
165+
// orbital_pdm_shell:[1,Inl,nm*nm]; \langle \phi_\mu|\alpha\rangle\langle\alpha|\phi_\nu\ranlge
163166
double**** orbital_pdm_shell;
164167
// orbital_precalc:[1,NAt,NDscrpt]; gvdm*orbital_pdm_shell
165168
torch::Tensor orbital_precalc_tensor;
166169

167170
// v_delta_pdm_shell[nks,nlocal,nlocal,Inl,nm*nm] = overlap * overlap
168171
double***** v_delta_pdm_shell;
172+
std::complex<double>***** v_delta_pdm_shell_complex; // for multi-k
169173
// v_delta_precalc[nks,nlocal,nlocal,NAt,NDscrpt] = gvdm * v_delta_pdm_shell;
170174
torch::Tensor v_delta_precalc_tensor;
171175
//for v_delta==2 , new v_delta_precalc storage method
@@ -220,6 +224,7 @@ class LCAO_Deepks
220224
void init(const LCAO_Orbitals& orb,
221225
const int nat,
222226
const int ntype,
227+
const int nks,
223228
const Parallel_Orbitals& pv_in,
224229
std::vector<int> na);
225230

@@ -437,12 +442,15 @@ class LCAO_Deepks
437442
// 11. cal_orbital_precalc_k : orbital_precalc is usted for training with orbital label,
438443
// for multi-k case, which equals gvdm * orbital_pdm_shell,
439444
// orbital_pdm_shell[1,Inl,nm*nm] = dm_hl_k * overlap * overlap
440-
//12. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
445+
// 12. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
441446
// which equals gvdm * v_delta_pdm_shell,
442447
// v_delta_pdm_shell = overlap * overlap
443-
//13. check_v_delta_precalc : check v_delta_precalc
444-
//14. prepare_psialpha : prepare psialpha for outputting npy file
445-
//15. prepare_gevdm : prepare gevdm for outputting npy file
448+
// 13. cal_v_delta_precalc_k : v_delta_precalc is used for training with v_delta label,
449+
// for multi-k case, which equals ???
450+
// ???
451+
// 14. check_v_delta_precalc : check v_delta_precalc
452+
// 15. prepare_psialpha : prepare psialpha for outputting npy file
453+
// 16. prepare_gevdm : prepare gevdm for outputting npy file
446454

447455
public:
448456
/// Calculates descriptors
@@ -500,6 +508,14 @@ class LCAO_Deepks
500508
const LCAO_Orbitals &orb,
501509
Grid_Driver &GridD);
502510

511+
void cal_v_delta_precalc_k(const int nlocal,
512+
const int nat,
513+
const int nks,
514+
const std::vector<ModuleBase::Vector3<double>> &kvec_d,
515+
const UnitCell &ucell,
516+
const LCAO_Orbitals &orb,
517+
Grid_Driver &GridD);
518+
503519
void check_v_delta_precalc(const int nat, const int nks,const int nlocal);
504520

505521
// prepare psialpha for outputting npy file
@@ -508,6 +524,13 @@ class LCAO_Deepks
508524
const UnitCell &ucell,
509525
const LCAO_Orbitals &orb,
510526
Grid_Driver &GridD);
527+
void prepare_psialpha_k(const int nlocal,
528+
const int nat,
529+
const int nks,
530+
const std::vector<ModuleBase::Vector3<double>> &kvec_d,
531+
const UnitCell &ucell,
532+
const LCAO_Orbitals &orb,
533+
Grid_Driver &GridD);
511534
void check_vdp_psialpha(const int nat, const int nks, const int nlocal);
512535

513536
// prepare gevdm for outputting npy file

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ void LCAO_Deepks_Interface::out_deepks_labels(const double& etot,
291291

292292
if (PARAM.inp.deepks_scf)
293293
{
294-
int nocc = PARAM.inp.nelec / 2;
294+
int nocc = PARAM.inp.nelec / 2; // redundant!
295295
ModuleBase::matrix wg_hl;
296296
wg_hl.create(nks, PARAM.inp.nbands);
297297
std::vector<std::vector<ModuleBase::ComplexMatrix>> dm_bandgap_k;
@@ -333,7 +333,94 @@ void LCAO_Deepks_Interface::out_deepks_labels(const double& etot,
333333
} // end bandgap label
334334
if(deepks_v_delta)
335335
{
336-
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "V_delta label has not been developed for multi-k now!");
336+
std::vector<ModuleBase::ComplexMatrix> h_tot(nks);
337+
for (int ik = 0; ik < nks; ik++)
338+
{
339+
h_tot[ik].create(nlocal, nlocal);
340+
}
341+
342+
DeePKS_domain::collect_h_mat(*ParaV, ld->h_mat_k,h_tot,nlocal,nks);
343+
344+
const std::string file_htot = PARAM.globalv.global_out_dir + "deepks_htot.npy";
345+
LCAO_deepks_io::save_npy_h(h_tot, file_htot, nlocal, nks, my_rank);
346+
347+
if(PARAM.inp.deepks_scf)
348+
{
349+
std::vector<ModuleBase::ComplexMatrix> v_delta(nks);
350+
std::vector<ModuleBase::ComplexMatrix> hbase(nks);
351+
for (int ik = 0; ik < nks; ik++)
352+
{
353+
v_delta[ik].create(nlocal, nlocal);
354+
hbase[ik].create(nlocal, nlocal);
355+
}
356+
DeePKS_domain::collect_h_mat(*ParaV, ld->H_V_delta_k,v_delta,nlocal,nks);
357+
358+
const std::string file_hbase = PARAM.globalv.global_out_dir + "deepks_hbase.npy";
359+
for (int ik = 0; ik < nks; ik++)
360+
{
361+
hbase[ik] = h_tot[ik] - v_delta[ik];
362+
}
363+
LCAO_deepks_io::save_npy_h(hbase, file_hbase, nlocal, nks, my_rank);
364+
365+
const std::string file_vdelta = PARAM.globalv.global_out_dir + "deepks_vdelta.npy";
366+
LCAO_deepks_io::save_npy_h(v_delta, file_vdelta, nlocal, nks, my_rank);
367+
368+
if(deepks_v_delta==1)//v_delta_precalc storage method 1
369+
{
370+
ld->cal_v_delta_precalc_k(nlocal,
371+
nat,
372+
nks,
373+
kvec_d,
374+
ucell,
375+
orb,
376+
GridD);
377+
378+
LCAO_deepks_io::save_npy_v_delta_precalc(
379+
nat,
380+
nks,
381+
nlocal,
382+
ld->des_per_atom,
383+
ld->v_delta_precalc_tensor,
384+
PARAM.globalv.global_out_dir,
385+
my_rank);
386+
387+
}
388+
else if(deepks_v_delta==2)//v_delta_precalc storage method 2
389+
{
390+
ld->prepare_psialpha_k(nlocal,
391+
nat,
392+
nks,
393+
kvec_d,
394+
ucell,
395+
orb,
396+
GridD);
397+
398+
LCAO_deepks_io::save_npy_psialpha(nat,
399+
nks,
400+
nlocal,
401+
ld->inlmax,
402+
ld->lmaxd,
403+
ld->psialpha_tensor,
404+
PARAM.globalv.global_out_dir,
405+
my_rank);
406+
407+
ld->prepare_gevdm(
408+
nat,
409+
orb);
410+
411+
LCAO_deepks_io::save_npy_gevdm(nat,
412+
ld->inlmax,
413+
ld->lmaxd,
414+
ld->gevdm_tensor,
415+
PARAM.globalv.global_out_dir,
416+
my_rank);
417+
}
418+
}
419+
else //deepks_scf == 0
420+
{
421+
const std::string file_hbase = PARAM.globalv.global_out_dir + "deepks_hbase.npy";
422+
LCAO_deepks_io::save_npy_h(h_tot, file_hbase, nlocal, nks, my_rank);
423+
}
337424
}
338425
} // end deepks_out_labels
339426

0 commit comments

Comments
 (0)