Skip to content

Commit 5827cfe

Browse files
committed
1. update Tensor::shape from vector<size_t> to Shape_Vector
1 parent d8b56e0 commit 5827cfe

File tree

8 files changed

+117
-14
lines changed

8 files changed

+117
-14
lines changed

include/RI/global/Shared_Vector.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.02.24
4+
// ===================
5+
6+
#pragma once
7+
8+
#include <cereal/cereal.hpp>
9+
#include <initializer_list>
10+
#include <cassert>
11+
12+
namespace RI
13+
{
14+
15+
class Shape_Vector
16+
{
17+
public:
18+
Shape_Vector()=default;
19+
Shape_Vector(const Shape_Vector &v_in)=default;
20+
Shape_Vector(Shape_Vector &&v_in)=default;
21+
Shape_Vector &operator=(const Shape_Vector &v_in)=default;
22+
Shape_Vector &operator=(Shape_Vector &&v_in)=default;
23+
Shape_Vector(const std::initializer_list<std::size_t> &v_in)
24+
:size_(v_in.size())
25+
{
26+
assert(v_in.size()<=sizeof(v)/sizeof(*v));
27+
std::size_t* ptr_this = this->v;
28+
for(auto ptr_in=v_in.begin(); ptr_in<v_in.end(); )
29+
*(ptr_this++) = *(ptr_in++);
30+
}
31+
32+
const std::size_t* begin() const noexcept { return this->v; }
33+
const std::size_t* end() const noexcept { return this->v+size_; }
34+
std::size_t size() const noexcept { return size_; }
35+
bool empty() const noexcept{ return !size_; }
36+
37+
std::size_t& operator[] (const std::size_t i)
38+
{
39+
assert(i<size_);
40+
return this->v[i];
41+
}
42+
const std::size_t& operator[] (const std::size_t i) const
43+
{
44+
assert(i<size_);
45+
return this->v[i];
46+
}
47+
48+
template <class Archive> void serialize( Archive & ar ){ ar(cereal::binary_data(this->v,sizeof(v)), size_); } // for cereal
49+
50+
public: //private:
51+
std::size_t v[4];
52+
std::size_t size_=0;
53+
};
54+
55+
}

include/RI/global/Tensor.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#pragma once
77

8+
#include "Shared_Vector.h"
89
#include "Global_Func-2.h"
910
#include <memory>
1011
#include <vector>
@@ -16,16 +17,16 @@
1617

1718
namespace RI
1819
{
19-
20+
2021
template<typename T>
2122
class Tensor
2223
{
2324
public:
24-
std::vector<std::size_t> shape;
25+
Shape_Vector shape;
2526
std::shared_ptr<std::valarray<T>> data=nullptr;
2627

27-
explicit inline Tensor (const std::vector<std::size_t> &shape_in);
28-
explicit inline Tensor (const std::vector<std::size_t> &shape_in, std::shared_ptr<std::valarray<T>> data_in);
28+
explicit inline Tensor (const Shape_Vector &shape_in);
29+
explicit inline Tensor (const Shape_Vector &shape_in, std::shared_ptr<std::valarray<T>> data_in);
2930

3031
Tensor()=default;
3132
Tensor(const Tensor<T> &t_in)=default;
@@ -34,7 +35,7 @@ class Tensor
3435
Tensor<T> &operator=(Tensor<T> &&t_in)=default;
3536

3637
inline std::size_t get_shape_all() const;
37-
inline Tensor reshape (const std::vector<std::size_t> &shape_in) const;
38+
inline Tensor reshape (const Shape_Vector &shape_in) const;
3839

3940
Tensor copy() const;
4041

@@ -56,7 +57,7 @@ class Tensor
5657
//Tensor & operator += (const Tensor &);
5758
Tensor operator-() const;
5859

59-
template <class Archive> void serialize( Archive & ar ){ ar(shape); ar(data); } // for cereal
60+
template <class Archive> void serialize( Archive & ar ){ ar(shape, data); } // for cereal
6061
};
6162

6263

@@ -97,7 +98,6 @@ template<typename T, std::size_t N0, std::size_t N1, std::size_t N2>
9798
extern std::array<std::array<std::array<T,N2>,N1>,N0> to_array(const Tensor<T> &t);
9899
template<typename T, std::size_t N0, std::size_t N1, std::size_t N2, std::size_t N3>
99100
extern std::array<std::array<std::array<std::array<T,N3>,N2>,N1>,N0> to_array(const Tensor<T> &t);
100-
101101
}
102102

103103
#include "Blas_Interface-Tensor.h"

include/RI/global/Tensor.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace RI
2020
{
2121

2222
template<typename T>
23-
Tensor<T>::Tensor (const std::vector<std::size_t> &shape_in)
23+
Tensor<T>::Tensor (const Shape_Vector &shape_in)
2424
{
2525
this->shape = shape_in;
2626
if(!this->shape.empty())
@@ -32,7 +32,7 @@ Tensor<T>::Tensor (const std::vector<std::size_t> &shape_in)
3232
}
3333

3434
template<typename T>
35-
Tensor<T>::Tensor (const std::vector<std::size_t> &shape_in, std::shared_ptr<std::valarray<T>> data_in)
35+
Tensor<T>::Tensor (const Shape_Vector &shape_in, std::shared_ptr<std::valarray<T>> data_in)
3636
{
3737
assert( std::accumulate(shape_in.begin(), shape_in.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>() ) == data_in->size() );
3838
this->shape = shape_in;
@@ -46,7 +46,7 @@ std::size_t Tensor<T>::get_shape_all() const
4646
}
4747

4848
template<typename T>
49-
Tensor<T> Tensor<T>::reshape (const std::vector<std::size_t> &shape_in) const
49+
Tensor<T> Tensor<T>::reshape (const Shape_Vector &shape_in) const
5050
{
5151
assert(
5252
std::accumulate(shape_in.begin(), shape_in.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>())

include/RI/ri/CS_Matrix_Tools.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ namespace CS_Matrix_Tools
8686
std::vector<Tensor<Tdata>> Ds_sub;
8787
Ds_sub.reserve(D.shape[2]);
8888
for(std::size_t i2=0; i2<D.shape[2]; ++i2)
89-
Ds_sub.emplace_back(std::vector<std::size_t>{D.shape[0],D.shape[1]});
89+
Ds_sub.emplace_back(Shape_Vector{D.shape[0],D.shape[1]});
9090

9191
const Tdata* D_ptr = D.ptr();
9292
std::vector<Tdata*> Ds_sub_ptr(D.shape[2]);

unittests/global/Tensor-test-4.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.02.24
4+
// ===================
5+
6+
#include "Tensor-test.h"
7+
#include "Tensor-test-2.hpp"
8+
#include "../include/RI/global/Cereal_Types.h"
9+
10+
#include <Comm/global/Cereal_Func.h>
11+
#include <Comm/global/MPI_Wrapper.h>
12+
13+
#include <mpi.h>
14+
#include <iostream>
15+
16+
namespace Tensor_Test
17+
{
18+
void test_cereal(int argc, char *argv[])
19+
{
20+
int mpi_init_provide;
21+
MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mpi_init_provide);
22+
23+
assert(Comm::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD)>=2);
24+
const int rank_mine = Comm::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD);
25+
26+
if(rank_mine==0)
27+
{
28+
const RI::Tensor<double> m = Tensor_Test::init_real_1<double>();
29+
Comm::Cereal_Func::mpi_send(1, 0, MPI_COMM_WORLD, m);
30+
std::cout<<m<<std::endl;
31+
}
32+
else if(rank_mine==1)
33+
{
34+
RI::Tensor<double> m;
35+
Comm::Cereal_Func::mpi_recv(MPI_COMM_WORLD, m);
36+
std::cout<<m<<std::endl;
37+
}
38+
39+
MPI_Finalize();
40+
}
41+
}

unittests/global/Tensor-test.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77

88
#include "RI/global/Tensor.h"
99

10+
std::ostream &operator<<(std::ostream &os, const RI::Shape_Vector &v)
11+
{
12+
for(std::size_t i=0; i<v.size(); ++i)
13+
os<<v[i]<<"\t";
14+
return os;
15+
}
16+
1017
template<typename T>
1118
std::ostream &operator<<(std::ostream &os, const RI::Tensor<T> &t)
1219
{

unittests/ri/LRI-speed-test.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
namespace LRI_Speed_Test
2121
{
2222
template<typename Tdata>
23-
static RI::Tensor<Tdata> init_tensor(const std::vector<std::size_t> &shape)
23+
static RI::Tensor<Tdata> init_tensor(const RI::Shape_Vector &shape)
2424
{
2525
RI::Tensor<Tdata> D(shape);
2626
for(std::size_t i=0; i<D.data->size(); ++i)
@@ -50,7 +50,7 @@ namespace LRI_Speed_Test
5050
using T_Ds = std::map<int, std::map<std::pair<int,std::array<int,Ndim>>, RI::Tensor<Tdata>>>;
5151
std::unordered_map<RI::Label::ab, T_Ds> Ds_ab;
5252
Ds_ab.reserve(11);
53-
auto init_Ds = [&Ds_ab, NA](const RI::Label::ab &label, const std::vector<std::size_t> &shape)
53+
auto init_Ds = [&Ds_ab, NA](const RI::Label::ab &label, const RI::Shape_Vector &shape)
5454
{
5555
const int rank_mine = RI::MPI_Wrapper::mpi_get_rank(MPI_COMM_WORLD);
5656
const int rank_size = RI::MPI_Wrapper::mpi_get_size(MPI_COMM_WORLD);

unittests/ri/LRI-test.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ namespace LRI_Test
4949
};
5050

5151
template<typename Tdata>
52-
static RI::Tensor<Tdata> init_tensor(const std::vector<std::size_t> &shape)
52+
static RI::Tensor<Tdata> init_tensor(const RI::Shape_Vector &shape)
5353
{
5454
RI::Tensor<Tdata> D(shape);
5555
for(std::size_t i=0; i<D.data->size(); ++i)

0 commit comments

Comments
 (0)