@@ -87,6 +87,16 @@ T& Tensor<T>::operator() (const std::size_t i0, const std::size_t i1, const std:
8787 assert (i2>=0 ); assert (i2<this ->shape [2 ]);
8888 return (*this ->data )[(i0*this ->shape [1 ]+i1)*this ->shape [2 ]+i2];
8989}
90+ template <typename T>
91+ T& Tensor<T>::operator () (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const
92+ {
93+ assert (this ->shape .size ()==3 );
94+ assert (i0>=0 ); assert (i0<this ->shape [0 ]);
95+ assert (i1>=0 ); assert (i1<this ->shape [1 ]);
96+ assert (i2>=0 ); assert (i2<this ->shape [2 ]);
97+ assert (i3>=0 ); assert (i3<this ->shape [3 ]);
98+ return (*this ->data )[((i0*this ->shape [1 ]+i1)*this ->shape [2 ]+i2)*this ->shape [3 ]+i3];
99+ }
90100
91101template <typename T1, typename T2>
92102bool same_shape (const Tensor<T1> &t1, const Tensor<T2> &t2)
@@ -232,6 +242,17 @@ Tensor<T> to_Tensor(const std::array<std::array<std::array<T,N2>,N1>,N0> &a)
232242 t (i0,i1,i2) = a[i0][i1][i2];
233243 return t;
234244}
245+ template <typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
246+ Tensor<T> to_Tensor (const std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> &a)
247+ {
248+ Tensor<T> t ({N0,N1,N2,N3});
249+ for (std::size_t i0=0 ; i0<N0; ++i0)
250+ for (std::size_t i1=0 ; i1<N1; ++i1)
251+ for (std::size_t i2=0 ; i2<N2; ++i2)
252+ for (std::size_t i3=0 ; i3<N3; ++i3)
253+ t (i0,i1,i2,i3) = a[i0][i1][i2][i3];
254+ return t;
255+ }
235256
236257template <typename T, std::size_t N0>
237258std::array<T,N0> to_array (const Tensor<T> &t)
@@ -269,5 +290,21 @@ std::array<std::array<std::array<T,N2>,N1>,N0> to_array(const Tensor<T> &t)
269290 a[i0][i1][i2] = t (i0,i1,i2);
270291 return a;
271292}
293+ template <typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
294+ std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> to_array (const Tensor<T> &t)
295+ {
296+ assert (t.shape .size ()==4 );
297+ assert (t.shape [0 ]==N0);
298+ assert (t.shape [1 ]==N1);
299+ assert (t.shape [2 ]==N2);
300+ assert (t.shape [3 ]==N3);
301+ std::array<std::array<T,N1>,N0> a;
302+ for (std::size_t i0=0 ; i0<N0; ++i0)
303+ for (std::size_t i1=0 ; i1<N1; ++i1)
304+ for (std::size_t i2=0 ; i2<N2; ++i2)
305+ for (std::size_t i3=0 ; i3<N3; ++i3)
306+ a[i0][i1][i2][i3] = t (i0,i1,i2,i3);
307+ return a;
308+ }
272309
273310}
0 commit comments