Skip to content

Commit d8b56e0

Browse files
committed
1. update Tensor
1 parent fbe7d0a commit d8b56e0

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

include/RI/global/Tensor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Tensor
4141
inline T& operator() (const std::size_t i0) const;
4242
inline T& operator() (const std::size_t i0, const std::size_t i1) const;
4343
inline T& operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2) const;
44+
inline T& operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const;
4445

4546
Tensor transpose() const;
4647

@@ -85,13 +86,17 @@ template<typename T, std::size_t N0, std::size_t N1>
8586
extern Tensor<T> to_Tensor(const std::array<std::array<T,N1>,N0> &a);
8687
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2>
8788
extern Tensor<T> to_Tensor(const std::array<std::array<std::array<T,N2>,N1>,N0> &a);
89+
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
90+
extern Tensor<T> to_Tensor(const std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> &a);
8891

8992
template<typename T, std::size_t N0>
9093
extern std::array<T,N0> to_array(const Tensor<T> &t);
9194
template<typename T, std::size_t N0, std::size_t N1>
9295
extern std::array<std::array<T,N1>,N0> to_array(const Tensor<T> &t);
9396
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2>
9497
extern std::array<std::array<std::array<T,N2>,N1>,N0> to_array(const Tensor<T> &t);
98+
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
99+
extern std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> to_array(const Tensor<T> &t);
95100

96101
}
97102

include/RI/global/Tensor.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ T& Tensor<T>::operator() (const std::size_t i0, const std::size_t i1, const std:
8787
assert(i2>=0); assert(i2<this->shape[2]);
8888
return (*this->data)[(i0*this->shape[1]+i1)*this->shape[2]+i2];
8989
}
90+
template<typename T>
91+
T& Tensor<T>::operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const
92+
{
93+
assert(this->shape.size()==3);
94+
assert(i0>=0); assert(i0<this->shape[0]);
95+
assert(i1>=0); assert(i1<this->shape[1]);
96+
assert(i2>=0); assert(i2<this->shape[2]);
97+
assert(i3>=0); assert(i3<this->shape[3]);
98+
return (*this->data)[((i0*this->shape[1]+i1)*this->shape[2]+i2)*this->shape[3]+i3];
99+
}
90100

91101
template<typename T1, typename T2>
92102
bool same_shape (const Tensor<T1> &t1, const Tensor<T2> &t2)
@@ -232,6 +242,17 @@ Tensor<T> to_Tensor(const std::array<std::array<std::array<T,N2>,N1>,N0> &a)
232242
t(i0,i1,i2) = a[i0][i1][i2];
233243
return t;
234244
}
245+
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
246+
Tensor<T> to_Tensor(const std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> &a)
247+
{
248+
Tensor<T> t({N0,N1,N2,N3});
249+
for(std::size_t i0=0; i0<N0; ++i0)
250+
for(std::size_t i1=0; i1<N1; ++i1)
251+
for(std::size_t i2=0; i2<N2; ++i2)
252+
for(std::size_t i3=0; i3<N3; ++i3)
253+
t(i0,i1,i2,i3) = a[i0][i1][i2][i3];
254+
return t;
255+
}
235256

236257
template<typename T, std::size_t N0>
237258
std::array<T,N0> to_array(const Tensor<T> &t)
@@ -269,5 +290,21 @@ std::array<std::array<std::array<T,N2>,N1>,N0> to_array(const Tensor<T> &t)
269290
a[i0][i1][i2] = t(i0,i1,i2);
270291
return a;
271292
}
293+
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
294+
std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> to_array(const Tensor<T> &t)
295+
{
296+
assert(t.shape.size()==4);
297+
assert(t.shape[0]==N0);
298+
assert(t.shape[1]==N1);
299+
assert(t.shape[2]==N2);
300+
assert(t.shape[3]==N3);
301+
std::array<std::array<T,N1>,N0> a;
302+
for(std::size_t i0=0; i0<N0; ++i0)
303+
for(std::size_t i1=0; i1<N1; ++i1)
304+
for(std::size_t i2=0; i2<N2; ++i2)
305+
for(std::size_t i3=0; i3<N3; ++i3)
306+
a[i0][i1][i2][i3] = t(i0,i1,i2,i3);
307+
return a;
308+
}
272309

273310
}

0 commit comments

Comments
 (0)