Skip to content

Commit 6190d29

Browse files
committed
Update the file names of nnof
1 parent 54e1d8f commit 6190d29

File tree

5 files changed

+21
-49
lines changed

5 files changed

+21
-49
lines changed

source/module_hamilt_pw/hamilt_ofdft/ml_tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ else()
2121
endif()
2222
include_directories(${libnpy_SOURCE_DIR}/include)
2323

24-
add_executable(nnof main.cpp data.cpp nn_of.cpp grid.cpp input.cpp kernel.cpp pauli_potential.cpp train.cpp)
24+
add_executable(nnof main.cpp data.cpp nn_of.cpp grid.cpp input.cpp kernel.cpp pauli_potential.cpp train_kedf.cpp)
2525
target_link_libraries(nnof "${TORCH_LIBRARIES}")
2626
set_property(TARGET nnof PROPERTY CXX_STANDARD 14)
2727

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
#include "./train.h"
2-
// #include <sstream>
3-
// #include <math.h>
4-
// #include "time.h"
1+
#include "./train_kedf.h"
52

63
int main()
74
{
85
torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble));
96
auto output = torch::get_default_dtype();
107
std::cout << "Default type: " << output << std::endl;
118

12-
Train train;
9+
Train_KEDF train;
1310
train.input.readInput();
1411
if (train.input.check_pot)
1512
{
@@ -20,29 +17,4 @@ int main()
2017
train.init();
2118
train.train();
2219
}
23-
24-
// torch::Tensor x = torch::ones({2,2});
25-
// x[0][0] = 0.;
26-
// x[1][0] = 2.;
27-
// x[1][1] = 3.;
28-
// x.requires_grad_(true);
29-
// std::cout << "x" << x << std::endl;
30-
// torch::Tensor y = x * x + x.t() * x.t();
31-
// std::cout << "y" << y << std::endl;
32-
// std::vector<torch::Tensor> tmp_pot;
33-
// std::vector<torch::Tensor> tmp_ipt;
34-
// std::vector<torch::Tensor> tmp_eye;
35-
// tmp_pot.push_back(y);
36-
// tmp_ipt.push_back(x);
37-
// tmp_eye.push_back(x);
38-
// // tmp_eye.push_back(torch::ones_like(y));
39-
// std::vector<torch::Tensor> tmp_grad;
40-
// tmp_grad = torch::autograd::grad(tmp_pot, tmp_ipt, tmp_eye, true, true, true);
41-
// std::cout << tmp_grad[0] << std::endl;
42-
43-
// load test
44-
// std::shared_ptr<NN_OFImpl> nn = std::make_shared<NN_OFImpl>(64000, 6);
45-
// torch::load(nn, "./net.pt");
46-
// nn->setData(train.gamma, train.gammanl, train.p, train.pnl, train.q, train.qnl);
47-
// std::cout << nn->forward(nn->inputs);
4820
}

source/module_hamilt_pw/hamilt_ofdft/ml_tools/pauli_potential.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef POTENTIAL_H
2-
#define POTENTIAL_H
1+
#ifndef PAULI_POTENTIAL_H
2+
#define PAULI_POTENTIAL_H
33

44
#include <torch/torch.h>
55
#include "./input.h"

source/module_hamilt_pw/hamilt_ofdft/ml_tools/train.cpp renamed to source/module_hamilt_pw/hamilt_ofdft/ml_tools/train_kedf.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
#include "./train.h"
1+
#include "./train_kedf.h"
22
#include <sstream>
33
#include <math.h>
44
#include <chrono>
55

6-
Train::~Train()
6+
Train_KEDF::~Train_KEDF()
77
{
88
delete[] this->train_volume;
99
delete[] this->vali_volume;
1010
delete[] this->kernel_train;
1111
delete[] this->kernel_vali;
1212
}
1313

14-
void Train::setUpFFT()
14+
void Train_KEDF::setUpFFT()
1515
{
1616
this->train_volume = new double[this->input.ntrain];
1717
this->grid_train.initGrid(
@@ -66,7 +66,7 @@ void Train::setUpFFT()
6666
// this->dumpTensor(this->fft_kernel_vali[0].reshape({this->data_train.nx}), "kernel_bcc.npy", this->data_train.nx);
6767
}
6868

69-
void Train::set_device()
69+
void Train_KEDF::set_device()
7070
{
7171
if (this->input.device_type == "cpu")
7272
{
@@ -89,7 +89,7 @@ void Train::set_device()
8989
}
9090
}
9191

92-
void Train::init_input_index()
92+
void Train_KEDF::init_input_index()
9393
{
9494
this->ninput = 0;
9595

@@ -195,7 +195,7 @@ void Train::init_input_index()
195195
std::cout << "feg_limit = " << this->input.feg_limit << std::endl;
196196
}
197197

198-
void Train::init()
198+
void Train_KEDF::init()
199199
{
200200
this->set_device();
201201
this->init_input_index();
@@ -214,18 +214,18 @@ void Train::init()
214214
this->nn->set_data(&(this->data_vali), this->descriptor_type, this->kernel_index, this->nn->input_vali);
215215
}
216216

217-
torch::Tensor Train::lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef)
217+
torch::Tensor Train_KEDF::lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef)
218218
{
219219
return torch::sum(torch::pow(enhancement - target, 2))/this->data_train.nx/coef/coef;
220220
}
221221

222-
torch::Tensor Train::lossFunction_new(torch::Tensor enhancement, torch::Tensor target, torch::Tensor weight, torch::Tensor coef)
222+
torch::Tensor Train_KEDF::lossFunction_new(torch::Tensor enhancement, torch::Tensor target, torch::Tensor weight, torch::Tensor coef)
223223
{
224224
return torch::sum(torch::pow(weight * (enhancement - target), 2.))/this->data_train.nx/coef/coef;
225225
}
226226

227227

228-
void Train::train()
228+
void Train_KEDF::train()
229229
{
230230
// time
231231
double tot = 0.;
@@ -239,7 +239,7 @@ void Train::train()
239239

240240
start = std::chrono::high_resolution_clock::now();
241241

242-
std::cout << "========== Train begin ==========" << std::endl;
242+
std::cout << "========== Train_KEDF begin ==========" << std::endl;
243243
// torch::Tensor target = (this->input.loss=="energy") ? this->data_train.enhancement : this->data_train.pauli;
244244
if (this->input.loss == "potential" || this->input.loss == "both" || this->input.loss == "both_new")
245245
{
@@ -427,7 +427,7 @@ void Train::train()
427427
std::cout << "Step\t\t\t" << totStep << "\t\t" << totStep/tot * 100. << " %" << std::endl;
428428
}
429429

430-
void Train::potTest()
430+
void Train_KEDF::potTest()
431431
{
432432
this->set_device();
433433
this->init_input_index();

source/module_hamilt_pw/hamilt_ofdft/ml_tools/train.h renamed to source/module_hamilt_pw/hamilt_ofdft/ml_tools/train_kedf.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TRAIN_H
2-
#define TRAIN_H
1+
#ifndef TRAIN_KEDF_H
2+
#define TRAIN_KEDF_H
33

44
#include "./data.h"
55
#include "./grid.h"
@@ -10,11 +10,11 @@
1010

1111
#include <torch/torch.h>
1212

13-
class Train
13+
class Train_KEDF
1414
{
1515
public:
16-
Train(){};
17-
~Train();
16+
Train_KEDF(){};
17+
~Train_KEDF();
1818

1919
std::shared_ptr<NN_OFImpl> nn;
2020
Input input;

0 commit comments

Comments
 (0)