@@ -17,8 +17,7 @@ Data::~Data()
1717
1818void Data::load_data (Input &input, const int ndata, std::string *dir, const torch::Device device)
1919{
20- if (ndata <= 0 ) { return ;
21- }
20+ if (ndata <= 0 ) return ;
2221 this ->init_label (input);
2322 this ->init_data (input.nkernel , ndata, input.fftdim , device);
2423 this ->load_data_ (input, ndata, input.fftdim , dir);
@@ -30,36 +29,51 @@ void Data::load_data(Input &input, const int ndata, std::string *dir, const torc
3029}
3130
3231torch::Tensor Data::get_data (std::string parameter, const int ikernel){
33- if (parameter == " gamma" ) { return this ->gamma .reshape ({this ->nx_tot });
34- }
35- if (parameter == " p" ) { return this ->p .reshape ({this ->nx_tot });
36- }
37- if (parameter == " q" ) { return this ->q .reshape ({this ->nx_tot });
38- }
39- if (parameter == " tanhp" ) { return this ->tanhp .reshape ({this ->nx_tot });
40- }
41- if (parameter == " tanhq" ) { return this ->tanhq .reshape ({this ->nx_tot });
42- }
43- if (parameter == " gammanl" ) { return this ->gammanl [ikernel].reshape ({this ->nx_tot });
44- }
45- if (parameter == " pnl" ) { return this ->pnl [ikernel].reshape ({this ->nx_tot });
46- }
47- if (parameter == " qnl" ) { return this ->qnl [ikernel].reshape ({this ->nx_tot });
48- }
49- if (parameter == " xi" ) { return this ->xi [ikernel].reshape ({this ->nx_tot });
50- }
51- if (parameter == " tanhxi" ) { return this ->tanhxi [ikernel].reshape ({this ->nx_tot });
52- }
53- if (parameter == " tanhxi_nl" ) { return this ->tanhxi_nl [ikernel].reshape ({this ->nx_tot });
54- }
55- if (parameter == " tanh_pnl" ) { return this ->tanh_pnl [ikernel].reshape ({this ->nx_tot });
56- }
57- if (parameter == " tanh_qnl" ) { return this ->tanh_qnl [ikernel].reshape ({this ->nx_tot });
58- }
59- if (parameter == " tanhp_nl" ) { return this ->tanhp_nl [ikernel].reshape ({this ->nx_tot });
60- }
61- if (parameter == " tanhq_nl" ) { return this ->tanhq_nl [ikernel].reshape ({this ->nx_tot });
62- }
32+ if (parameter == " gamma" ){
33+ return this ->gamma .reshape ({this ->nx_tot });
34+ }
35+ if (parameter == " p" ){
36+ return this ->p .reshape ({this ->nx_tot });
37+ }
38+ if (parameter == " q" ){
39+ return this ->q .reshape ({this ->nx_tot });
40+ }
41+ if (parameter == " tanhp" ){
42+ return this ->tanhp .reshape ({this ->nx_tot });
43+ }
44+ if (parameter == " tanhq" ){
45+ return this ->tanhq .reshape ({this ->nx_tot });
46+ }
47+ if (parameter == " gammanl" ){
48+ return this ->gammanl [ikernel].reshape ({this ->nx_tot });
49+ }
50+ if (parameter == " pnl" ){
51+ return this ->pnl [ikernel].reshape ({this ->nx_tot });
52+ }
53+ if (parameter == " qnl" ){
54+ return this ->qnl [ikernel].reshape ({this ->nx_tot });
55+ }
56+ if (parameter == " xi" ){
57+ return this ->xi [ikernel].reshape ({this ->nx_tot });
58+ }
59+ if (parameter == " tanhxi" ){
60+ return this ->tanhxi [ikernel].reshape ({this ->nx_tot });
61+ }
62+ if (parameter == " tanhxi_nl" ){
63+ return this ->tanhxi_nl [ikernel].reshape ({this ->nx_tot });
64+ }
65+ if (parameter == " tanh_pnl" ){
66+ return this ->tanh_pnl [ikernel].reshape ({this ->nx_tot });
67+ }
68+ if (parameter == " tanh_qnl" ){
69+ return this ->tanh_qnl [ikernel].reshape ({this ->nx_tot });
70+ }
71+ if (parameter == " tanhp_nl" ){
72+ return this ->tanhp_nl [ikernel].reshape ({this ->nx_tot });
73+ }
74+ if (parameter == " tanhq_nl" ){
75+ return this ->tanhq_nl [ikernel].reshape ({this ->nx_tot });
76+ }
6377 return torch::zeros ({});
6478}
6579
@@ -139,25 +153,31 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
139153 this ->nx_tot = this ->nx * ndata;
140154
141155 this ->rho = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
142- if (this ->load_p ) { this ->nablaRho = torch::zeros ({ndata, 3 , fftdim, fftdim, fftdim}).to (device);
143- }
156+ if (this ->load_p ){
157+ this ->nablaRho = torch::zeros ({ndata, 3 , fftdim, fftdim, fftdim}).to (device);
158+ }
144159
145160 this ->enhancement = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
146161 this ->enhancement_mean = torch::zeros (ndata).to (device);
147162 this ->tau_mean = torch::zeros (ndata).to (device);
148163 this ->pauli = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
149164 this ->pauli_mean = torch::zeros (ndata).to (device);
150165
151- if (this ->load_gamma ) { this ->gamma = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
152- }
153- if (this ->load_p ) { this ->p = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
154- }
155- if (this ->load_q ) { this ->q = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
156- }
157- if (this ->load_tanhp ) { this ->tanhp = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
158- }
159- if (this ->load_tanhq ) { this ->tanhq = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
160- }
166+ if (this ->load_gamma ){
167+ this ->gamma = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
168+ }
169+ if (this ->load_p ){
170+ this ->p = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
171+ }
172+ if (this ->load_q ){
173+ this ->q = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
174+ }
175+ if (this ->load_tanhp ){
176+ this ->tanhp = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
177+ }
178+ if (this ->load_tanhq ){
179+ this ->tanhq = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
180+ }
161181
162182 for (int ik = 0 ; ik < nkernel; ++ik)
163183 {
@@ -172,26 +192,36 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
172192 this ->tanhp_nl .push_back (torch::zeros ({}).to (device));
173193 this ->tanhq_nl .push_back (torch::zeros ({}).to (device));
174194
175- if (this ->load_gammanl [ik]) { this ->gammanl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
176- }
177- if (this ->load_pnl [ik]) { this ->pnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
178- }
179- if (this ->load_qnl [ik]) { this ->qnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
180- }
181- if (this ->load_xi [ik]) { this ->xi [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
182- }
183- if (this ->load_tanhxi [ik]) { this ->tanhxi [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
184- }
185- if (this ->load_tanhxi_nl [ik]) { this ->tanhxi_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
186- }
187- if (this ->load_tanh_pnl [ik]) { this ->tanh_pnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
188- }
189- if (this ->load_tanh_qnl [ik]) { this ->tanh_qnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
190- }
191- if (this ->load_tanhp_nl [ik]) { this ->tanhp_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
192- }
193- if (this ->load_tanhq_nl [ik]) { this ->tanhq_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
194- }
195+ if (this ->load_gammanl [ik]){
196+ this ->gammanl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
197+ }
198+ if (this ->load_pnl [ik]){
199+ this ->pnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
200+ }
201+ if (this ->load_qnl [ik]){
202+ this ->qnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
203+ }
204+ if (this ->load_xi [ik]){
205+ this ->xi [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
206+ }
207+ if (this ->load_tanhxi [ik]){
208+ this ->tanhxi [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
209+ }
210+ if (this ->load_tanhxi_nl [ik{
211+ this ->tanhxi_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
212+ }
213+ if (this ->load_tanh_pnl [ik]){
214+ this ->tanh_pnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
215+ }
216+ if (this ->load_tanh_qnl [ik]){
217+ this ->tanh_qnl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
218+ }
219+ if (this ->load_tanhp_nl [ik]){
220+ this ->tanhp_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
221+ }
222+ if (this ->load_tanhq_nl [ik]){
223+ this ->tanhq_nl [ik] = torch::zeros ({ndata, fftdim, fftdim, fftdim}).to (device);
224+ }
195225 }
196226
197227 // Input::print("init_data done");
@@ -205,20 +235,21 @@ void Data::load_data_(
205235)
206236{
207237 // Input::print("load_data");
208- if (ndata <= 0 ) { return ;
209- }
210-
238+ if (ndata <= 0 ){
239+ return ;
240+ }
241+
211242 std::vector<long unsigned int > cshape = {(long unsigned ) nx};
212243 std::vector<double > container (nx);
213244 bool fortran_order = false ;
214245
215246 for (int idata = 0 ; idata < ndata; ++idata)
216247 {
217248 this ->loadTensor (dir[idata] + " /rho.npy" , cshape, fortran_order, container, idata, fftdim, rho);
218- if (this ->load_gamma ) { this -> loadTensor (dir[idata] + " /gamma.npy " , cshape, fortran_order, container, idata, fftdim, gamma);
219- }
220- if ( this -> load_p )
221- {
249+ if (this ->load_gamma ){
250+ this -> loadTensor (dir[idata] + " /gamma.npy " , cshape, fortran_order, container, idata, fftdim, gamma);
251+ }
252+ if ( this -> load_p ) {
222253 this ->loadTensor (dir[idata] + " /p.npy" , cshape, fortran_order, container, idata, fftdim, p);
223254 npy::LoadArrayFromNumpy (dir[idata] + " /nablaRhox.npy" , cshape, fortran_order, container);
224255 nablaRho[idata][0 ] = torch::tensor (container).reshape ({fftdim, fftdim, fftdim});
@@ -227,12 +258,15 @@ void Data::load_data_(
227258 npy::LoadArrayFromNumpy (dir[idata] + " /nablaRhoz.npy" , cshape, fortran_order, container);
228259 nablaRho[idata][2 ] = torch::tensor (container).reshape ({fftdim, fftdim, fftdim});
229260 }
230- if (this ->load_q ) { this ->loadTensor (dir[idata] + " /q.npy" , cshape, fortran_order, container, idata, fftdim, q);
231- }
232- if (this ->load_tanhp ) { this ->loadTensor (dir[idata] + " /tanhp.npy" , cshape, fortran_order, container, idata, fftdim, tanhp);
233- }
234- if (this ->load_tanhq ) { this ->loadTensor (dir[idata] + " /tanhq.npy" , cshape, fortran_order, container, idata, fftdim, tanhq);
235- }
261+ if (this ->load_q ){
262+ this ->loadTensor (dir[idata] + " /q.npy" , cshape, fortran_order, container, idata, fftdim, q);
263+ }
264+ if (this ->load_tanhp ){
265+ this ->loadTensor (dir[idata] + " /tanhp.npy" , cshape, fortran_order, container, idata, fftdim, tanhp);
266+ }
267+ if (this ->load_tanhq ){
268+ this ->loadTensor (dir[idata] + " /tanhq.npy" , cshape, fortran_order, container, idata, fftdim, tanhq);
269+ }
236270
237271 for (int ik = 0 ; ik < input.nkernel ; ++ik)
238272 {
@@ -305,8 +339,9 @@ void Data::loadTensor(
305339void Data::dumpTensor (const torch::Tensor &data, std::string filename, int nx)
306340{
307341 std::vector<double > v (nx);
308- for (int ir = 0 ; ir < nx; ++ir) { v[ir] = data[ir].item <double >();
309- }
342+ for (int ir = 0 ; ir < nx; ++ir){
343+ v[ir] = data[ir].item <double >();
344+ }
310345 // std::vector<double> v(data.data_ptr<float>(), data.data_ptr<float>() + data.numel()); // this works, but only supports float tensor
311346 const long unsigned cshape[] = {(long unsigned ) nx}; // shape
312347 npy::SaveArrayAsNumpy (filename, false , 1 , cshape, v);
0 commit comments