Skip to content

Commit a0840ed

Browse files
sunliang98mohanchenCopilot
authored
Fix: Fix several bugs in KEDF get_energy functions (#7169)
* Fix: Fix the get_energy function of several KEDFs * Fix: Fix generate_descriptor * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Mohan Chen <mohanchen@pku.edu.cn> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bf3fb24 commit a0840ed

8 files changed

Lines changed: 12 additions & 8 deletions

File tree

source/source_estate/module_pot/pot_ml_exx.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
6565
if (this->descriptor_type[i] == "gamma") feg_inpt[i] = 1.;
6666
}
6767

68-
if (PARAM.inp.of_ml_feg == 1)
68+
if (PARAM.inp.of_ml_feg == 1)
69+
{
6970
this->feg_net_F = torch::softplus(this->nn->forward(feg_inpt)).to(this->device_CPU).contiguous().data_ptr<double>()[0];
71+
}
7072
else
7173
{
7274
this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr<double>()[0];

source/source_io/module_ml/write_mlkedf_descriptors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ void Write_MLKEDF_Descriptors::generate_descriptor(
171171

172172
// p
173173
this->cal_tool->getP(prho, pw_rho, nablaRho, container);
174-
npy::SaveArrayAsNumpy("p.npy", false, 1, cshape, container);
174+
npy::SaveArrayAsNumpy(out_dir + "/p.npy", false, 1, cshape, container);
175175

176176
for (int ik = 0; ik < this->cal_tool->nkernel; ++ik)
177177
{

source/source_pw/module_ofdft/kedf_lkt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ double KEDF_LKT::get_energy(const double* const* prho, ModulePW::PW_Basis* pw_rh
5252
}
5353
delete[] nabla_rho;
5454

55-
return energy;
55+
return this->lkt_energy;
5656
}
5757

5858
/**

source/source_pw/module_ofdft/kedf_ml.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ void KEDF_ML::set_para(
8787
if (this->descriptor_type[i] == "gamma") feg_inpt[i] = 1.;
8888
}
8989

90-
if (PARAM.inp.of_ml_feg == 1)
90+
if (PARAM.inp.of_ml_feg == 1)
91+
{
9192
this->feg_net_F = torch::softplus(this->nn->forward(feg_inpt)).to(this->device_CPU).contiguous().data_ptr<double>()[0];
93+
}
9294
else
9395
{
9496
this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr<double>()[0];

source/source_pw/module_ofdft/kedf_tf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ double KEDF_TF::get_energy(const double* const* prho)
4343
}
4444
this->tf_energy = energy;
4545
Parallel_Reduce::reduce_all(this->tf_energy);
46-
return energy;
46+
return this->tf_energy;
4747
}
4848

4949
/**

source/source_pw/module_ofdft/kedf_vw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ double KEDF_vW::get_energy(double** pphi, ModulePW::PW_Basis* pw_rho)
6969
delete[] tempPhi;
7070
delete[] LapPhi;
7171

72-
return energy;
72+
return this->vw_energy;
7373
}
7474

7575
/**

source/source_pw/module_ofdft/kedf_wt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ double KEDF_WT::get_energy(const double* const* prho, ModulePW::PW_Basis* pw_rho
110110
}
111111
delete[] kernelRhoBeta;
112112

113-
return energy;
113+
return this->wt_energy;
114114
}
115115

116116
/**

source/source_pw/module_ofdft/kedf_xwm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ double KEDF_XWM::get_energy(const double* const* prho, ModulePW::PW_Basis* pw_rh
103103
delete[] w1Rho5_6;
104104
delete[] w2Rho5_6;
105105

106-
return energy;
106+
return this->xwm_energy;
107107
}
108108

109109
/**

0 commit comments

Comments
 (0)