@@ -17,7 +17,8 @@ Data::~Data()
 
 void Data::load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
 {
-    if (ndata <= 0) return;
+    if (ndata <= 0) { return;
+    }
     this->init_label(input);
     this->init_data(input.nkernel, ndata, input.fftdim, device);
     this->load_data_(input, ndata, input.fftdim, dir);
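For context, a minimal call-site sketch of the guarded loader (hypothetical setup; only `Input::nkernel`, `Input::fftdim`, and the `load_data` signature above are taken from this diff):

```cpp
#include <string>
#include <torch/torch.h>
// Assumes the Data/Input declarations from this file's header are visible.

void example(Input &input)
{
    std::string dirs[2] = {"./set0", "./set1"};  // hypothetical data directories
    Data data;
    // With ndata <= 0, the braced early return above makes this a no-op.
    data.load_data(input, /*ndata=*/2, dirs, torch::Device(torch::kCPU));
}
```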
@@ -29,21 +30,36 @@ void Data::load_data(Input &input, const int ndata, std::string *dir, const torc
 }
 
 torch::Tensor Data::get_data(std::string parameter, const int ikernel){
-    if (parameter == "gamma") return this->gamma.reshape({this->nx_tot});
-    if (parameter == "p") return this->p.reshape({this->nx_tot});
-    if (parameter == "q") return this->q.reshape({this->nx_tot});
-    if (parameter == "tanhp") return this->tanhp.reshape({this->nx_tot});
-    if (parameter == "tanhq") return this->tanhq.reshape({this->nx_tot});
-    if (parameter == "gammanl") return this->gammanl[ikernel].reshape({this->nx_tot});
-    if (parameter == "pnl") return this->pnl[ikernel].reshape({this->nx_tot});
-    if (parameter == "qnl") return this->qnl[ikernel].reshape({this->nx_tot});
-    if (parameter == "xi") return this->xi[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanhxi") return this->tanhxi[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanhxi_nl") return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanh_pnl") return this->tanh_pnl[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanh_qnl") return this->tanh_qnl[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanhp_nl") return this->tanhp_nl[ikernel].reshape({this->nx_tot});
-    if (parameter == "tanhq_nl") return this->tanhq_nl[ikernel].reshape({this->nx_tot});
+    if (parameter == "gamma") { return this->gamma.reshape({this->nx_tot});
+    }
+    if (parameter == "p") { return this->p.reshape({this->nx_tot});
+    }
+    if (parameter == "q") { return this->q.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp") { return this->tanhp.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq") { return this->tanhq.reshape({this->nx_tot});
+    }
+    if (parameter == "gammanl") { return this->gammanl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "pnl") { return this->pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "qnl") { return this->qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "xi") { return this->xi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi") { return this->tanhxi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi_nl") { return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_pnl") { return this->tanh_pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_qnl") { return this->tanh_qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp_nl") { return this->tanhp_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq_nl") { return this->tanhq_nl[ikernel].reshape({this->nx_tot});
+    }
     return torch::zeros({});
 }
 
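The string dispatch in `get_data` could equally be table-driven; below is a hypothetical sketch (not part of this commit), assuming the per-kernel members are `std::vector<torch::Tensor>`, consistent with the `push_back` calls later in this diff:

```cpp
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>

// Hypothetical lookup-table variant of Data::get_data; the member names
// come from the diff above, the maps themselves are illustrative.
torch::Tensor Data::get_data(std::string parameter, const int ikernel)
{
    // Per-grid quantities stored as plain tensors.
    const std::map<std::string, torch::Tensor*> scalar = {
        {"gamma", &gamma}, {"p", &p}, {"q", &q}, {"tanhp", &tanhp}, {"tanhq", &tanhq}};
    // Per-kernel quantities, one tensor per kernel index.
    const std::map<std::string, std::vector<torch::Tensor>*> per_kernel = {
        {"gammanl", &gammanl}, {"pnl", &pnl}, {"qnl", &qnl}, {"xi", &xi},
        {"tanhxi", &tanhxi}, {"tanhxi_nl", &tanhxi_nl}, {"tanh_pnl", &tanh_pnl},
        {"tanh_qnl", &tanh_qnl}, {"tanhp_nl", &tanhp_nl}, {"tanhq_nl", &tanhq_nl}};

    auto s = scalar.find(parameter);
    if (s != scalar.end()) { return s->second->reshape({this->nx_tot}); }
    auto k = per_kernel.find(parameter);
    if (k != per_kernel.end()) { return (*k->second)[ikernel].reshape({this->nx_tot}); }
    return torch::zeros({});  // fallback, as in the original
}
```

Since the maps point at instance members, they would be rebuilt on every call, which likely costs more than the if-chain saves; keeping the straightforward chain, as this commit does, is a reasonable choice.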
@@ -123,19 +139,25 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
     this->nx_tot = this->nx * ndata;
 
     this->rho = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_p) this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
+    if (this->load_p) { this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
+    }
 
     this->enhancement = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
     this->enhancement_mean = torch::zeros(ndata).to(device);
     this->tau_mean = torch::zeros(ndata).to(device);
     this->pauli = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
     this->pauli_mean = torch::zeros(ndata).to(device);
 
-    if (this->load_gamma) this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_p) this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_q) this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_tanhp) this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_tanhq) this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    if (this->load_gamma) { this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_p) { this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_q) { this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhp) { this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhq) { this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
 
     for (int ik = 0; ik < nkernel; ++ik)
     {
@@ -150,16 +172,26 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
         this->tanhp_nl.push_back(torch::zeros({}).to(device));
         this->tanhq_nl.push_back(torch::zeros({}).to(device));
 
-        if (this->load_gammanl[ik]) this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_pnl[ik]) this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_qnl[ik]) this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_xi[ik]) this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanhxi[ik]) this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanhxi_nl[ik]) this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanh_pnl[ik]) this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanh_qnl[ik]) this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanhp_nl[ik]) this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        if (this->load_tanhq_nl[ik]) this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        if (this->load_gammanl[ik]) { this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_pnl[ik]) { this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_qnl[ik]) { this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_xi[ik]) { this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi[ik]) { this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi_nl[ik]) { this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_pnl[ik]) { this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_qnl[ik]) { this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhp_nl[ik]) { this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhq_nl[ik]) { this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
     }
 
     // Input::print("init_data done");
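The repeated guarded allocations in `init_data` could be condensed with a small helper; a hypothetical sketch, not in this commit. As a side note, `torch::zeros({...}).to(device)` materializes the tensor on the CPU first and then copies, whereas passing the device through `TensorOptions` allocates directly on the target device:

```cpp
// Hypothetical helper condensing the guarded allocations above;
// alloc_if is illustrative and not part of this commit.
auto alloc_if = [&](bool flag, torch::Tensor &t) {
    if (flag) {
        // Allocates directly on `device`, skipping the intermediate CPU tensor.
        t = torch::zeros({ndata, fftdim, fftdim, fftdim},
                         torch::TensorOptions().device(device));
    }
};
alloc_if(this->load_gamma, this->gamma);
alloc_if(this->load_p,     this->p);
alloc_if(this->load_q,     this->q);
alloc_if(this->load_tanhp, this->tanhp);
alloc_if(this->load_tanhq, this->tanhq);
```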
@@ -173,7 +205,8 @@ void Data::load_data_(
 )
 {
     // Input::print("load_data");
-    if (ndata <= 0) return;
+    if (ndata <= 0) { return;
+    }
 
     std::vector<long unsigned int> cshape = {(long unsigned) nx};
     std::vector<double> container(nx);
@@ -182,7 +215,8 @@ void Data::load_data_(
     for (int idata = 0; idata < ndata; ++idata)
     {
         this->loadTensor(dir[idata] + "/rho.npy", cshape, fortran_order, container, idata, fftdim, rho);
-        if (this->load_gamma) this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
+        if (this->load_gamma) { this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
+        }
         if (this->load_p)
         {
             this->loadTensor(dir[idata] + "/p.npy", cshape, fortran_order, container, idata, fftdim, p);
@@ -193,9 +227,12 @@ void Data::load_data_(
             npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhoz.npy", cshape, fortran_order, container);
             nablaRho[idata][2] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
         }
-        if (this->load_q) this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
-        if (this->load_tanhp) this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
-        if (this->load_tanhq) this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
+        if (this->load_q) { this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
+        }
+        if (this->load_tanhp) { this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
+        }
+        if (this->load_tanhq) { this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
+        }
 
         for (int ik = 0; ik < input.nkernel; ++ik)
         {
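For reference, the `loadTensor` calls follow the same read-then-reshape pattern that the inlined `nablaRho` branch above makes explicit; a free-standing sketch with a hypothetical signature modeled on the call sites:

```cpp
#include <string>
#include <vector>
#include <torch/torch.h>
#include "npy.hpp"  // libnpy, as used elsewhere in this file

// Sketch of the loadTensor pattern inferred from the call sites above
// (hypothetical free function; the real member signature may differ):
// read one flat .npy array with libnpy, view it as an fftdim^3 cube,
// and store it in slot idata of the batched tensor.
void load_slot(const std::string &file,
               std::vector<long unsigned int> &cshape,
               bool &fortran_order,
               std::vector<double> &container,
               const int idata, const int fftdim,
               torch::Tensor &dst)
{
    npy::LoadArrayFromNumpy(file, cshape, fortran_order, container);
    dst[idata] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
}
```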
@@ -268,7 +305,8 @@ void Data::loadTensor(
 void Data::dumpTensor(const torch::Tensor &data, std::string filename, int nx)
 {
     std::vector<double> v(nx);
-    for (int ir = 0; ir < nx; ++ir) v[ir] = data[ir].item<double>();
+    for (int ir = 0; ir < nx; ++ir) { v[ir] = data[ir].item<double>();
+    }
     // std::vector<double> v(data.data_ptr<float>(), data.data_ptr<float>() + data.numel()); // this works, but only supports float tensor
     const long unsigned cshape[] = {(long unsigned) nx}; // shape
     npy::SaveArrayAsNumpy(filename, false, 1, cshape, v);
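The commented-out bulk copy in `dumpTensor` only handles float tensors; a dtype-safe variant (a sketch, not part of this commit) converts once and then copies through `data_ptr<double>()`, avoiding the per-element `.item<double>()` calls:

```cpp
// Hypothetical dtype-safe bulk copy for dumpTensor: convert to a
// contiguous double tensor once, then copy all nx values in one pass.
torch::Tensor d = data.to(torch::kDouble).contiguous();
std::vector<double> v(d.data_ptr<double>(), d.data_ptr<double>() + nx);
```

The `.contiguous()` call matters: `data_ptr` walks raw memory, so a non-contiguous view would otherwise be copied in the wrong order.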