Skip to content

Commit 2486f99

Browse files
authored
Merge pull request #884 from ouqi0711/develop
DeePKS orbital (bandgap) label for multi-k case
2 parents c43c3e4 + 702cb01 commit 2486f99

File tree

9 files changed

+314
-76
lines changed

9 files changed

+314
-76
lines changed

source/module_deepks/LCAO_deepks.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ class LCAO_Deepks
6767
///(Unit: Ry/Bohr) Total Force due to the DeePKS correction term \f$E_{\delta}\f$
6868
ModuleBase::matrix F_delta;
6969

70+
//k index of HOMO for multi-k bandgap label. QO added 2022-01-24
71+
int h_ind = 0;
72+
73+
//k index of LUMO for multi-k bandgap label. QO added 2022-01-24
74+
int l_ind = 0;
75+
7076
//-------------------
7177
// private variables
7278
//-------------------
@@ -409,6 +415,9 @@ class LCAO_Deepks
409415
//9. cal_orbital_precalc : orbital_precalc is usted for training with orbital label,
410416
// which equals gvdm * orbital_pdm_shell,
411417
// orbital_pdm_shells[1,Inl,nm*nm] = dm_hl * overlap * overlap
418+
//10. cal_orbital_precalc_k : orbital_precalc is usted for training with orbital label,
419+
// for multi-k case, which equals gvdm * orbital_pdm_shell,
420+
// orbital_pdm_shells[1,Inl,nm*nm] = dm_hl_k * overlap * overlap
412421

413422
public:
414423

@@ -444,6 +453,17 @@ class LCAO_Deepks
444453
const LCAO_Orbitals &orb,
445454
Grid_Driver &GridD,
446455
const Parallel_Orbitals &ParaO);
456+
457+
//calculates orbital_precalc for multi-k case
458+
void cal_orbital_precalc_k(const std::vector<ModuleBase::ComplexMatrix>& dm_hl/**<[in] density matrix*/,
459+
const int nat,
460+
const int nks,
461+
const std::vector<ModuleBase::Vector3<double>> &kvec_d,
462+
const UnitCell_pseudo &ucell,
463+
const LCAO_Orbitals &orb,
464+
Grid_Driver &GridD,
465+
const Parallel_Orbitals &ParaO);
466+
447467

448468
private:
449469
void cal_gvdm(const int nat);

source/module_deepks/LCAO_deepks_torch.cpp

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include "LCAO_deepks.h"
2727
#include "../src_parallel/parallel_reduce.h"
28+
#include "../module_base/constants.h"
2829

2930
//calculates descriptors from projected density matrices
3031
void LCAO_Deepks::cal_descriptor(void)
@@ -502,4 +503,183 @@ void LCAO_Deepks::cal_orbital_precalc(const std::vector<ModuleBase::matrix> &dm_
502503
return;
503504
}
504505

506+
// calculates orbital_precalc[1,NAt,NDscrpt] = gvdm * orbital_pdm_shells for multi-k case;
507+
// orbital_pdm_shells[1,Inl,nm*nm] = dm_hl_k * overlap * overlap;
508+
void LCAO_Deepks::cal_orbital_precalc_k(const std::vector<ModuleBase::ComplexMatrix> &dm_hl_k,
509+
const int nat,
510+
const int nks,
511+
const std::vector<ModuleBase::Vector3<double>> &kvec_d,
512+
const UnitCell_pseudo &ucell,
513+
const LCAO_Orbitals &orb,
514+
Grid_Driver &GridD,
515+
const Parallel_Orbitals &ParaO)
516+
{
517+
ModuleBase::TITLE("LCAO_Deepks", "calc_orbital_precalc_k");
518+
519+
this->cal_gvdm(nat);
520+
const double Rcut_Alpha = orb.Alpha[0].getRcut();
521+
this->init_orbital_pdm_shell();
522+
523+
for (int T0 = 0; T0 < ucell.ntype; T0++)
524+
{
525+
Atom* atom0 = &ucell.atoms[T0];
526+
527+
for (int I0 =0; I0< atom0->na; I0++)
528+
{
529+
const int iat = ucell.itia2iat(T0,I0);
530+
const ModuleBase::Vector3<double> tau0 = atom0->tau[I0];
531+
GridD.Find_atom(ucell, atom0->tau[I0] ,T0, I0);
532+
533+
for (int ad1=0; ad1<GridD.getAdjacentNum()+1 ; ++ad1)
534+
{
535+
const int T1 = GridD.getType(ad1);
536+
const int I1 = GridD.getNatom(ad1);
537+
const int ibt1 = ucell.itia2iat(T1,I1);
538+
const int start1 = ucell.itiaiw2iwt(T1, I1, 0);
539+
const ModuleBase::Vector3<double> tau1 = GridD.getAdjacentTau(ad1);
540+
541+
const Atom* atom1 = &ucell.atoms[T1];
542+
const int nw1_tot = atom1->nw*GlobalV::NPOL;
543+
const double Rcut_AO1 = orb.Phi[T1].getRcut();
544+
545+
ModuleBase::Vector3<double> dR1(GridD.getBox(ad1).x, GridD.getBox(ad1).y, GridD.getBox(ad1).z);
546+
547+
for (int ad2=0; ad2 < GridD.getAdjacentNum()+1 ; ad2++)
548+
{
549+
const int T2 = GridD.getType(ad2);
550+
const int I2 = GridD.getNatom(ad2);
551+
const int ibt2 = ucell.itia2iat(T2,I2);
552+
const int start2 = ucell.itiaiw2iwt(T2, I2, 0);
553+
const ModuleBase::Vector3<double> tau2 = GridD.getAdjacentTau(ad2);
554+
const Atom* atom2 = &ucell.atoms[T2];
555+
const int nw2_tot = atom2->nw*GlobalV::NPOL;
556+
ModuleBase::Vector3<double> dR2(GridD.getBox(ad2).x, GridD.getBox(ad2).y, GridD.getBox(ad2).z);
557+
558+
const double Rcut_AO2 = orb.Phi[T2].getRcut();
559+
const double dist1 = (tau1-tau0).norm() * ucell.lat0;
560+
const double dist2 = (tau2-tau0).norm() * ucell.lat0;
561+
562+
if (dist1 > Rcut_Alpha + Rcut_AO1 || dist2 > Rcut_Alpha + Rcut_AO2)
563+
{
564+
continue;
565+
}
566+
567+
for (int iw1=0; iw1<nw1_tot; ++iw1)
568+
{
569+
const int iw1_all = start1 + iw1; // this is \mu
570+
const int iw1_local = ParaO.trace_loc_col[iw1_all];
571+
if(iw1_local < 0)continue;
572+
573+
for (int iw2=0; iw2<nw2_tot; ++iw2)
574+
{
575+
const int iw2_all = start2 + iw2; // this is \nu
576+
const int iw2_local = ParaO.trace_loc_row[iw2_all];
577+
if(iw2_local < 0)continue;
578+
double dm_current;
579+
std::complex<double> tmp = 0.0;
580+
for(int ik=0;ik<nks;ik++)
581+
{
582+
const double arg = - (kvec_d[ik] * (dR2-dR1) ) * ModuleBase::TWO_PI;
583+
const std::complex<double> kphase = std::complex <double> ( cos(arg), sin(arg) );
584+
tmp += dm_hl_k[ik](iw1_local, iw2_local) * kphase;
585+
}
586+
dm_current=tmp.real();
587+
588+
key_tuple key_1(ibt1,dR1.x,dR1.y,dR1.z);
589+
key_tuple key_2(ibt2,dR2.x,dR2.y,dR2.z);
590+
std::vector<double> nlm1 = this->nlm_save_k[iat][key_1][iw1_all][0];
591+
std::vector<double> nlm2 = this->nlm_save_k[iat][key_2][iw2_all][0];
592+
assert(nlm1.size()==nlm2.size());
593+
594+
int ib=0;
595+
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax();++L0)
596+
{
597+
for (int N0 = 0;N0 < orb.Alpha[0].getNchi(L0);++N0)
598+
{
599+
const int inl = this->inl_index[T0](I0, L0, N0);
600+
const int nm = 2*L0+1;
601+
602+
for (int m1=0; m1<nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
603+
{
604+
for (int m2=0; m2<nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
605+
{
606+
607+
orbital_pdm_shell[0][inl][m1*nm+m2] += dm_current*nlm1[ib+m1]*nlm2[ib+m2];
608+
609+
}
610+
}
611+
ib+=nm;
612+
}
613+
}
614+
615+
}//iw2
616+
}//iw1
617+
}//ad2
618+
}//ad1
619+
620+
}
621+
}
622+
#ifdef __MPI
623+
for(int inl = 0; inl < this->inlmax; inl++)
624+
{
625+
Parallel_Reduce::reduce_double_all(this->orbital_pdm_shell[0][inl],(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
626+
}
627+
#endif
628+
629+
// transfer orbital_pdm_shell to orbital_pdm_shell_vector
630+
631+
632+
int nlmax = this->inlmax/nat;
633+
634+
std::vector<torch::Tensor> orbital_pdm_shell_vector;
635+
636+
for(int nl = 0; nl < nlmax; ++nl)
637+
{
638+
std::vector<torch::Tensor> iammv;
639+
for(int hl=0; hl<1; ++hl)
640+
{
641+
std::vector<torch::Tensor> ammv;
642+
for (int iat=0; iat<nat; ++iat)
643+
{
644+
int inl = iat*nlmax+nl;
645+
int nm = 2*this->inl_l[inl]+1;
646+
std::vector<double> mmv;
647+
648+
for (int m1=0; m1<nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
649+
{
650+
for (int m2=0; m2<nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
651+
{
652+
mmv.push_back(this->orbital_pdm_shell[hl][inl][m1*nm+m2]);
653+
}
654+
655+
}
656+
torch::Tensor mm = torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64) ).reshape({nm, nm}); //nm*nm
657+
ammv.push_back(mm);
658+
}
659+
660+
torch::Tensor amm = torch::stack(ammv, 0);
661+
iammv.push_back(amm);
662+
}
663+
664+
torch::Tensor iamm = torch::stack(iammv, 0); //inl*nm*nm
665+
orbital_pdm_shell_vector.push_back(iamm);
666+
}
667+
668+
669+
assert(orbital_pdm_shell_vector.size() == nlmax);
670+
671+
672+
//einsum for each nl:
673+
std::vector<torch::Tensor> orbital_precalc_vector;
674+
for (int nl = 0; nl<nlmax; ++nl)
675+
{
676+
orbital_precalc_vector.push_back(at::einsum("iamn, avmn->iav", {orbital_pdm_shell_vector[nl], this->gevdm_vector[nl]}));
677+
}
678+
679+
this->orbital_precalc_tensor = torch::cat(orbital_precalc_vector, -1);
680+
681+
this->del_orbital_pdm_shell();
682+
return;
683+
}
684+
505685
#endif

0 commit comments

Comments
 (0)