Skip to content

Commit 09547e3

Browse files
[pre-commit.ci lite] apply automatic fixes
1 parent 99eae78 commit 09547e3

File tree

2 files changed

+81
-42
lines changed

2 files changed

+81
-42
lines changed

source/module_hamilt_pw/hamilt_ofdft/ml_tools/data.cpp

Lines changed: 76 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ Data::~Data()
1717

1818
void Data::load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
1919
{
20-
if (ndata <= 0) return;
20+
if (ndata <= 0) { return;
21+
}
2122
this->init_label(input);
2223
this->init_data(input.nkernel, ndata, input.fftdim, device);
2324
this->load_data_(input, ndata, input.fftdim, dir);
@@ -29,21 +30,36 @@ void Data::load_data(Input &input, const int ndata, std::string *dir, const torc
2930
}
3031

3132
torch::Tensor Data::get_data(std::string parameter, const int ikernel){
32-
if (parameter == "gamma") return this->gamma.reshape({this->nx_tot});
33-
if (parameter == "p") return this->p.reshape({this->nx_tot});
34-
if (parameter == "q") return this->q.reshape({this->nx_tot});
35-
if (parameter == "tanhp") return this->tanhp.reshape({this->nx_tot});
36-
if (parameter == "tanhq") return this->tanhq.reshape({this->nx_tot});
37-
if (parameter == "gammanl") return this->gammanl[ikernel].reshape({this->nx_tot});
38-
if (parameter == "pnl") return this->pnl[ikernel].reshape({this->nx_tot});
39-
if (parameter == "qnl") return this->qnl[ikernel].reshape({this->nx_tot});
40-
if (parameter == "xi") return this->xi[ikernel].reshape({this->nx_tot});
41-
if (parameter == "tanhxi") return this->tanhxi[ikernel].reshape({this->nx_tot});
42-
if (parameter == "tanhxi_nl") return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
43-
if (parameter == "tanh_pnl") return this->tanh_pnl[ikernel].reshape({this->nx_tot});
44-
if (parameter == "tanh_qnl") return this->tanh_qnl[ikernel].reshape({this->nx_tot});
45-
if (parameter == "tanhp_nl") return this->tanhp_nl[ikernel].reshape({this->nx_tot});
46-
if (parameter == "tanhq_nl") return this->tanhq_nl[ikernel].reshape({this->nx_tot});
33+
if (parameter == "gamma") { return this->gamma.reshape({this->nx_tot});
34+
}
35+
if (parameter == "p") { return this->p.reshape({this->nx_tot});
36+
}
37+
if (parameter == "q") { return this->q.reshape({this->nx_tot});
38+
}
39+
if (parameter == "tanhp") { return this->tanhp.reshape({this->nx_tot});
40+
}
41+
if (parameter == "tanhq") { return this->tanhq.reshape({this->nx_tot});
42+
}
43+
if (parameter == "gammanl") { return this->gammanl[ikernel].reshape({this->nx_tot});
44+
}
45+
if (parameter == "pnl") { return this->pnl[ikernel].reshape({this->nx_tot});
46+
}
47+
if (parameter == "qnl") { return this->qnl[ikernel].reshape({this->nx_tot});
48+
}
49+
if (parameter == "xi") { return this->xi[ikernel].reshape({this->nx_tot});
50+
}
51+
if (parameter == "tanhxi") { return this->tanhxi[ikernel].reshape({this->nx_tot});
52+
}
53+
if (parameter == "tanhxi_nl") { return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
54+
}
55+
if (parameter == "tanh_pnl") { return this->tanh_pnl[ikernel].reshape({this->nx_tot});
56+
}
57+
if (parameter == "tanh_qnl") { return this->tanh_qnl[ikernel].reshape({this->nx_tot});
58+
}
59+
if (parameter == "tanhp_nl") { return this->tanhp_nl[ikernel].reshape({this->nx_tot});
60+
}
61+
if (parameter == "tanhq_nl") { return this->tanhq_nl[ikernel].reshape({this->nx_tot});
62+
}
4763
return torch::zeros({});
4864
}
4965

@@ -123,19 +139,25 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
123139
this->nx_tot = this->nx * ndata;
124140

125141
this->rho = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
126-
if (this->load_p) this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
142+
if (this->load_p) { this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
143+
}
127144

128145
this->enhancement = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
129146
this->enhancement_mean = torch::zeros(ndata).to(device);
130147
this->tau_mean = torch::zeros(ndata).to(device);
131148
this->pauli = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
132149
this->pauli_mean = torch::zeros(ndata).to(device);
133150

134-
if (this->load_gamma) this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
135-
if (this->load_p) this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
136-
if (this->load_q) this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
137-
if (this->load_tanhp) this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
138-
if (this->load_tanhq) this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
151+
if (this->load_gamma) { this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
152+
}
153+
if (this->load_p) { this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
154+
}
155+
if (this->load_q) { this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
156+
}
157+
if (this->load_tanhp) { this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
158+
}
159+
if (this->load_tanhq) { this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
160+
}
139161

140162
for (int ik = 0; ik < nkernel; ++ik)
141163
{
@@ -150,16 +172,26 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
150172
this->tanhp_nl.push_back(torch::zeros({}).to(device));
151173
this->tanhq_nl.push_back(torch::zeros({}).to(device));
152174

153-
if (this->load_gammanl[ik]) this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
154-
if (this->load_pnl[ik]) this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
155-
if (this->load_qnl[ik]) this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
156-
if (this->load_xi[ik]) this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
157-
if (this->load_tanhxi[ik]) this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
158-
if (this->load_tanhxi_nl[ik]) this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
159-
if (this->load_tanh_pnl[ik]) this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
160-
if (this->load_tanh_qnl[ik]) this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
161-
if (this->load_tanhp_nl[ik]) this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
162-
if (this->load_tanhq_nl[ik]) this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
175+
if (this->load_gammanl[ik]) { this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
176+
}
177+
if (this->load_pnl[ik]) { this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
178+
}
179+
if (this->load_qnl[ik]) { this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
180+
}
181+
if (this->load_xi[ik]) { this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
182+
}
183+
if (this->load_tanhxi[ik]) { this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
184+
}
185+
if (this->load_tanhxi_nl[ik]) { this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
186+
}
187+
if (this->load_tanh_pnl[ik]) { this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
188+
}
189+
if (this->load_tanh_qnl[ik]) { this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
190+
}
191+
if (this->load_tanhp_nl[ik]) { this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
192+
}
193+
if (this->load_tanhq_nl[ik]) { this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
194+
}
163195
}
164196

165197
// Input::print("init_data done");
@@ -173,7 +205,8 @@ void Data::load_data_(
173205
)
174206
{
175207
// Input::print("load_data");
176-
if (ndata <= 0) return;
208+
if (ndata <= 0) { return;
209+
}
177210

178211
std::vector<long unsigned int> cshape = {(long unsigned) nx};
179212
std::vector<double> container(nx);
@@ -182,7 +215,8 @@ void Data::load_data_(
182215
for (int idata = 0; idata < ndata; ++idata)
183216
{
184217
this->loadTensor(dir[idata] + "/rho.npy", cshape, fortran_order, container, idata, fftdim, rho);
185-
if (this->load_gamma) this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
218+
if (this->load_gamma) { this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
219+
}
186220
if (this->load_p)
187221
{
188222
this->loadTensor(dir[idata] + "/p.npy", cshape, fortran_order, container, idata, fftdim, p);
@@ -193,9 +227,12 @@ void Data::load_data_(
193227
npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhoz.npy", cshape, fortran_order, container);
194228
nablaRho[idata][2] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
195229
}
196-
if (this->load_q) this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
197-
if (this->load_tanhp) this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
198-
if (this->load_tanhq) this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
230+
if (this->load_q) { this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
231+
}
232+
if (this->load_tanhp) { this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
233+
}
234+
if (this->load_tanhq) { this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
235+
}
199236

200237
for (int ik = 0; ik < input.nkernel; ++ik)
201238
{
@@ -268,7 +305,8 @@ void Data::loadTensor(
268305
void Data::dumpTensor(const torch::Tensor &data, std::string filename, int nx)
269306
{
270307
std::vector<double> v(nx);
271-
for (int ir = 0; ir < nx; ++ir) v[ir] = data[ir].item<double>();
308+
for (int ir = 0; ir < nx; ++ir) { v[ir] = data[ir].item<double>();
309+
}
272310
// std::vector<double> v(data.data_ptr<float>(), data.data_ptr<float>() + data.numel()); // this works, but only supports float tensor
273311
const long unsigned cshape[] = {(long unsigned) nx}; // shape
274312
npy::SaveArrayAsNumpy(filename, false, 1, cshape, v);

source/module_hamilt_pw/hamilt_ofdft/ml_tools/kernel.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,11 @@ void Kernel::read_kernel(const int fftdim,
177177
eta = eta * this->scaling;
178178
maxEta = std::max(eta, maxEta);
179179

180-
if (eta <= eta_in[0])
180+
if (eta <= eta_in[0]) {
181181
this->kernel[id][ix][iy][iz] = w0_in[0];
182-
else if (eta > maxEta_in)
182+
} else if (eta > maxEta_in) {
183183
this->kernel[id][ix][iy][iz] = w0_in[nq_in-1];
184-
else
184+
} else
185185
{
186186
ind1 = 1;
187187
ind2 = nq_in;
@@ -205,7 +205,8 @@ void Kernel::read_kernel(const int fftdim,
205205
}
206206
}
207207
}
208-
if (maxEta > maxEta_in) std::cout << "Please increase the maximal eta value in KEDF kernel file" << std::endl;
208+
if (maxEta > maxEta_in) { std::cout << "Please increase the maximal eta value in KEDF kernel file" << std::endl;
209+
}
209210
}
210211

211212

0 commit comments

Comments
 (0)