@@ -39,6 +39,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
3939 // These variables are frequently used in the following code
4040 const int inlmax = orb.Alpha [0 ].getTotal_nchi () * nat;
4141 const int lmaxd = orb.get_lmax_d ();
42+ const int nmaxd = ld->nmaxd ;
4243
4344 const int des_per_atom = ld->des_per_atom ;
4445 const int * inl_l = ld->inl_l ;
@@ -49,11 +50,56 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
4950 bool init_pdm = ld->init_pdm ;
5051 double E_delta = ld->E_delta ;
5152 double e_delta_band = ld->e_delta_band ;
52- double ** gedm = ld->gedm ;
5353
5454 const int my_rank = GlobalV::MY_RANK;
5555 const int nspin = PARAM.inp .nspin ;
5656
57+ // Note : update PDM and all other quantities with the current dm
58+ // DeePKS PDM and descriptor
59+ if (PARAM.inp .deepks_out_labels || PARAM.inp .deepks_scf )
60+ {
61+ // this part is for integrated test of deepks
62+ // so it is printed no matter even if deepks_out_labels is not used
63+ DeePKS_domain::cal_pdm<TK>
64+ (init_pdm, inlmax, lmaxd, inl_l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
65+
66+ DeePKS_domain::check_pdm (inlmax, inl_l, pdm); // print out the projected dm for NSCF calculaiton
67+
68+ std::vector<torch::Tensor> descriptor;
69+ DeePKS_domain::cal_descriptor (nat, inlmax, inl_l, pdm, descriptor,
70+ des_per_atom); // final descriptor
71+ DeePKS_domain::check_descriptor (inlmax, des_per_atom, inl_l, ucell, PARAM.globalv .global_out_dir , descriptor);
72+
73+ if (PARAM.inp .deepks_out_labels )
74+ {
75+ LCAO_deepks_io::save_npy_d (nat,
76+ des_per_atom,
77+ inlmax,
78+ inl_l,
79+ PARAM.inp .deepks_equiv ,
80+ descriptor,
81+ PARAM.globalv .global_out_dir ,
82+ GlobalV::MY_RANK); // libnpy needed
83+ }
84+
85+ if (PARAM.inp .deepks_scf )
86+ {
87+ // update E_delta and gedm
88+ // new gedm is also useful in cal_f_delta, so it should be ld->gedm
89+ DeePKS_domain::cal_edelta_gedm (nat,
90+ lmaxd,
91+ nmaxd,
92+ inlmax,
93+ des_per_atom,
94+ inl_l,
95+ descriptor,
96+ pdm,
97+ ld->model_deepks ,
98+ ld->gedm ,
99+ E_delta);
100+ }
101+ }
102+
57103 // Used for deepks_bandgap == 1 and deepks_v_delta > 0
58104 std::vector<std::vector<TK>>* h_delta = nullptr ;
59105 if constexpr (std::is_same<TK, double >::value)
@@ -65,7 +111,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
65111 h_delta = &ld->H_V_delta_k ;
66112 }
67113
68- // calculating deepks correction to bandgap and save the results
114+ // calculating deepks correction and save the results
69115 if (PARAM.inp .deepks_out_labels )
70116 {
71117 // Used for deepks_scf == 1
@@ -317,38 +363,6 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
317363
318364 } // end deepks_out_labels
319365
320- // DeePKS PDM and descriptor
321- if (PARAM.inp .deepks_out_labels || PARAM.inp .deepks_scf )
322- {
323- // this part is for integrated test of deepks
324- // so it is printed no matter even if deepks_out_labels is not used
325- // when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
326- if (!PARAM.inp .deepks_scf )
327- {
328- DeePKS_domain::cal_pdm<
329- TK>(init_pdm, inlmax, lmaxd, inl_l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
330- }
331-
332- DeePKS_domain::check_pdm (inlmax, inl_l, pdm); // print out the projected dm for NSCF calculaiton
333-
334- std::vector<torch::Tensor> descriptor;
335- DeePKS_domain::cal_descriptor (nat, inlmax, inl_l, pdm, descriptor,
336- des_per_atom); // final descriptor
337- DeePKS_domain::check_descriptor (inlmax, des_per_atom, inl_l, ucell, PARAM.globalv .global_out_dir , descriptor);
338-
339- if (PARAM.inp .deepks_out_labels )
340- {
341- LCAO_deepks_io::save_npy_d (nat,
342- des_per_atom,
343- inlmax,
344- inl_l,
345- PARAM.inp .deepks_equiv ,
346- descriptor,
347- PARAM.globalv .global_out_dir ,
348- GlobalV::MY_RANK); // libnpy needed
349- }
350- }
351-
352366 // / print out deepks information to the screen
353367 if (PARAM.inp .deepks_scf )
354368 {
@@ -361,7 +375,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
361375 {
362376 LCAO_deepks_io::print_dm (nks, PARAM.globalv .nlocal , ParaV->nrow , dm->get_DMK_vector ());
363377
364- DeePKS_domain::check_gedm (inlmax, inl_l, gedm);
378+ DeePKS_domain::check_gedm (inlmax, inl_l, ld-> gedm );
365379
366380 std::ofstream ofs (" E_delta_bands.dat" );
367381 ofs << std::setprecision (10 ) << e_delta_band;
0 commit comments