Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions source/source_pw/module_ofdft/kedf_ml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,24 @@ void KEDF_ML::NN_forward(const double * const * prho, ModulePW::PW_Basis *pw_rho

void KEDF_ML::loadVector(std::string filename, std::vector<double> &data)
{
std::vector<long unsigned int> cshape = {(long unsigned) this->cal_tool->nx};
bool fortran_order = false;
npy::LoadArrayFromNumpy(filename, cshape, fortran_order, data);
npy::npy_data<double> d = npy::read_npy<double>(filename);
data = d.data;
// ========== For old version of npy.hpp ==========
// std::vector<long unsigned int> cshape = {(long unsigned) this->cal_tool->nx};
// bool fortran_order = false;
// npy::LoadArrayFromNumpy(filename, cshape, fortran_order, data);
}

void KEDF_ML::dumpVector(std::string filename, const std::vector<double> &data)
{
const long unsigned cshape[] = {(long unsigned) this->cal_tool->nx}; // shape
npy::SaveArrayAsNumpy(filename, false, 1, cshape, data);
npy::npy_data_ptr<double> d;
d.data_ptr = data.data();
d.shape = {(long unsigned) this->cal_tool->nx};
d.fortran_order = false; // optional
npy::write_npy(filename, d);
// ========== For old version of npy.hpp ==========
// const long unsigned cshape[] = {(long unsigned) this->cal_tool->nx}; // shape
// npy::SaveArrayAsNumpy(filename, false, 1, cshape, data);
}

/**
Expand Down
13 changes: 11 additions & 2 deletions source/source_pw/module_ofdft/ml_tools/data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
if (this->load_tanhxi[ik]){
this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
}
if (this->load_tanhxi_nl[ik{
if (this->load_tanhxi_nl[ik]){
this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
}
if (this->load_tanh_pnl[ik]){
Expand Down Expand Up @@ -319,7 +319,16 @@ void Data::load_data_(
enhancement.resize_({this->nx_tot, 1});
pauli.resize_({nx_tot, 1});

this->tau_tf = this->cTF * torch::pow(this->rho, 5./3.);
if (input.energy_type == "kedf")
{
this->tau_exp = 5. / 3.;
this->tau_lda = this->cTF * torch::pow(this->rho, this->tau_exp);
}
else if (input.energy_type == "exx")
{
this->tau_exp = 4. / 3.;
this->tau_lda = this->cDirac * torch::pow(this->rho, this->tau_exp);
}
// Input::print("load_data done");
}

Expand Down
6 changes: 4 additions & 2 deletions source/source_pw/module_ofdft/ml_tools/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Data
// =========== data ===========
torch::Tensor rho;
torch::Tensor nablaRho;
torch::Tensor tau_tf;
torch::Tensor tau_lda; // energy density of LDA, i.e. TF for KEDF, Dirac term for EXX
// semi-local descriptors
torch::Tensor gamma;
torch::Tensor p;
Expand Down Expand Up @@ -67,7 +67,9 @@ class Data
void init_data(const int nkernel, const int ndata, const int fftdim, const torch::Device device);
void load_data_(Input &input, const int ndata, const int fftdim, std::string *dir);

const double cTF = 3.0/10.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
const double cTF = 3. /10. * std::pow(3. * std::pow(M_PI, 2.), 2. / 3.) * 2.; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
const double cDirac = - 3. /4. * std::pow(3. / M_PI, 1./3.) * 2.; // -3/4*(3/pi)^{1/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
double tau_exp = 5. / 3.; // 5/3 for TF KEDF, and 4/3 for Dirac term

public:
void loadTensor(std::string file,
Expand Down
4 changes: 4 additions & 0 deletions source/source_pw/module_ofdft/ml_tools/input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ void Input::readInput()
{
this->read_value(ifs, this->device_type);
}
else if (strcmp("energy_type", word) == 0)
{
this->read_value(ifs, this->energy_type);
}
}

std::cout << "Read nnINPUT done" << std::endl;
Expand Down
1 change: 1 addition & 0 deletions source/source_pw/module_ofdft/ml_tools/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Input
double lr_end = 1e-4;
int lr_fre = 5000;
double exponent = 5.; // exponent of weight rho^{exponent/3.}
std::string energy_type = "kedf"; // kedf or exx

// output
int dump_fre = 1;
Expand Down
Loading
Loading