Skip to content

Commit 6ad28d1

Browse files
committed
Feature: Support ML EXX for training script.
1 parent 47bfd69 commit 6ad28d1

File tree

6 files changed

+88
-58
lines changed

6 files changed

+88
-58
lines changed

source/source_pw/hamilt_ofdft/ml_tools/data.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
208208
if (this->load_tanhxi[ik]){
209209
this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
210210
}
211-
if (this->load_tanhxi_nl[ik{
211+
if (this->load_tanhxi_nl[ik]){
212212
this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
213213
}
214214
if (this->load_tanh_pnl[ik]){
@@ -319,7 +319,16 @@ void Data::load_data_(
319319
enhancement.resize_({this->nx_tot, 1});
320320
pauli.resize_({nx_tot, 1});
321321

322-
this->tau_tf = this->cTF * torch::pow(this->rho, 5./3.);
322+
if (input.energy_type == "kedf")
323+
{
324+
this->tau_exp = 5. / 3.;
325+
this->tau_lda = this->cTF * torch::pow(this->rho, this->tau_exp);
326+
}
327+
else if (input.energy_type == "exx")
328+
{
329+
this->tau_exp = 4. / 3.;
330+
this->tau_lda = this->cDirac * torch::pow(this->rho, this->tau_exp);
331+
}
323332
// Input::print("load_data done");
324333
}
325334

source/source_pw/hamilt_ofdft/ml_tools/data.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Data
1717
// =========== data ===========
1818
torch::Tensor rho;
1919
torch::Tensor nablaRho;
20-
torch::Tensor tau_tf;
20+
torch::Tensor tau_lda; // energy density of LDA, i.e. TF for KEDF, Dirac term for EXX
2121
// semi-local descriptors
2222
torch::Tensor gamma;
2323
torch::Tensor p;
@@ -67,7 +67,9 @@ class Data
6767
void init_data(const int nkernel, const int ndata, const int fftdim, const torch::Device device);
6868
void load_data_(Input &input, const int ndata, const int fftdim, std::string *dir);
6969

70-
const double cTF = 3.0/10.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
70+
const double cTF = 3. /10. * std::pow(3. * std::pow(M_PI, 2.), 2. / 3.) * 2.; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
71+
const double cDirac = - 3. /4. * std::pow(3. / M_PI, 1./3.) * 2.; // -3/4*(3/pi)^{1/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
72+
double tau_exp = 5. / 3.; // 5/3 for TF KEDF, and 4/3 for Dirac term
7173

7274
public:
7375
void loadTensor(std::string file,

source/source_pw/hamilt_ofdft/ml_tools/input.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,10 @@ void Input::readInput()
277277
{
278278
this->read_value(ifs, this->device_type);
279279
}
280+
else if (strcmp("energy_type", word) == 0)
281+
{
282+
this->read_value(ifs, this->energy_type);
283+
}
280284
}
281285

282286
std::cout << "Read nnINPUT done" << std::endl;

source/source_pw/hamilt_ofdft/ml_tools/input.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Input
7373
double lr_end = 1e-4;
7474
int lr_fre = 5000;
7575
double exponent = 5.; // exponent of weight rho^{exponent/3.}
76+
std::string energy_type = "kedf"; // kedf or exx
7677

7778
// output
7879
int dump_fre = 1;

0 commit comments

Comments
 (0)