Skip to content

Commit 23f6e19

Browse files
committed
Simplify some function for LCAO_deepks_io.
1 parent a542095 commit 23f6e19

File tree

4 files changed

+73
-434
lines changed

4 files changed

+73
-434
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -695,17 +695,14 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
695695
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
696696
if (PARAM.inp.deepks_scf)
697697
{
698-
699698
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
700699
file_sbase,
701700
ucell.omega,
702701
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
703702
}
704703
else
705704
{
706-
LCAO_deepks_io::save_npy_s(scs,
707-
file_sbase,
708-
ucell.omega,
705+
LCAO_deepks_io::save_npy_s(scs, file_sbase, ucell.omega,
709706
GlobalV::MY_RANK); // sbase = stot
710707
}
711708
}

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
9797

9898
torch::Tensor gvx;
9999
DeePKS_domain::cal_gvx(ucell.nat, inlmax, des_per_atom, inl_l, gevdm, gdmx, gvx);
100-
LCAO_deepks_io::save_npy_gvx(ucell.nat,
101-
des_per_atom,
102-
gvx,
103-
PARAM.globalv.global_out_dir,
104-
GlobalV::MY_RANK);
100+
const std::string file_gradvx = PARAM.globalv.global_out_dir + "deepks_gradvx.npy";
101+
LCAO_deepks_io::save_tensor2npy<double>(file_gradvx, gvx, my_rank);
105102

106103
if (PARAM.inp.deepks_out_unittest)
107104
{
@@ -124,11 +121,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
124121

125122
torch::Tensor gvepsl;
126123
DeePKS_domain::cal_gvepsl(ucell.nat, inlmax, des_per_atom, inl_l, gevdm, gdmepsl, gvepsl);
127-
LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
128-
des_per_atom,
129-
gvepsl,
130-
PARAM.globalv.global_out_dir,
131-
GlobalV::MY_RANK);
124+
const std::string file_gvepsl = PARAM.globalv.global_out_dir + "deepks_gvepsl.npy";
125+
LCAO_deepks_io::save_tensor2npy<double>(file_gvepsl, gvepsl, my_rank);
132126

133127
if (PARAM.inp.deepks_out_unittest)
134128
{
@@ -205,12 +199,9 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
205199
DeePKS_domain::cal_o_delta<TK, TH>(dm_bandgap, *h_delta, o_delta, *ParaV, nks);
206200

207201
// save obase and orbital_precalc
208-
LCAO_deepks_io::save_npy_orbital_precalc(nat,
209-
nks,
210-
des_per_atom,
211-
orbital_precalc,
212-
PARAM.globalv.global_out_dir,
213-
my_rank);
202+
const std::string file_orbpre = PARAM.globalv.global_out_dir + "deepks_orbpre.npy";
203+
LCAO_deepks_io::save_tensor2npy<double>(file_orbpre, orbital_precalc, my_rank);
204+
214205
const std::string file_obase = PARAM.globalv.global_out_dir + "deepks_obase.npy";
215206
std::vector<double> o_base(nks);
216207
for (int iks = 0; iks < nks; ++iks)
@@ -297,38 +288,21 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
297288
GridD,
298289
v_delta_precalc);
299290

300-
LCAO_deepks_io::save_npy_v_delta_precalc<TK>(nat,
301-
nks,
302-
nlocal,
303-
des_per_atom,
304-
v_delta_precalc,
305-
PARAM.globalv.global_out_dir,
306-
my_rank);
291+
const std::string file_vdpre = PARAM.globalv.global_out_dir + "deepks_vdpre.npy";
292+
LCAO_deepks_io::save_tensor2npy<TK>(file_vdpre, v_delta_precalc, my_rank);
307293
}
308294
else if (PARAM.inp.deepks_v_delta == 2) // v_delta_precalc storage method 2
309295
{
310296
torch::Tensor phialpha_out;
311297
DeePKS_domain::prepare_phialpha<
312298
TK>(nlocal, lmaxd, inlmax, nat, nks, kvec_d, phialpha, ucell, orb, *ParaV, GridD, phialpha_out);
313-
314-
LCAO_deepks_io::save_npy_phialpha<TK>(nat,
315-
nks,
316-
nlocal,
317-
inlmax,
318-
lmaxd,
319-
phialpha_out,
320-
PARAM.globalv.global_out_dir,
321-
my_rank);
299+
const std::string file_phialpha = PARAM.globalv.global_out_dir + "deepks_phialpha.npy";
300+
LCAO_deepks_io::save_tensor2npy<TK>(file_phialpha, phialpha_out, my_rank);
322301

323302
torch::Tensor gevdm_out;
324303
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
325-
326-
LCAO_deepks_io::save_npy_gevdm(nat,
327-
inlmax,
328-
lmaxd,
329-
gevdm_out,
330-
PARAM.globalv.global_out_dir,
331-
my_rank);
304+
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
305+
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, my_rank);
332306
}
333307
}
334308
else // deepks_scf == 0

0 commit comments

Comments
 (0)