Skip to content

Commit a524007

Browse files
committed
NineToothedTensor 进行 C++ 层封装
1 parent badccb8 commit a524007

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

src/infiniop/ninetoothed/utils.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include <initializer_list>
2+
#include <limits>
3+
#include <type_traits>
4+
#include <vector>
5+
6+
namespace ninetoothed {
7+
8+
template <typename T = float>
9+
class Tensor {
10+
public:
11+
using Data = decltype(NineToothedTensor::data);
12+
13+
using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;
14+
15+
using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;
16+
17+
template <typename Shape, typename Strides>
18+
Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}
19+
20+
Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}
21+
22+
Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}
23+
24+
operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; }
25+
26+
template <typename Shape>
27+
Tensor expand(const Shape &sizes) const {
28+
auto new_ndim{sizes.size()};
29+
30+
decltype(shape_) shape(new_ndim, 1);
31+
decltype(strides_) strides(new_ndim, 0);
32+
33+
auto num_new_dims{new_ndim - ndim_};
34+
35+
for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
36+
shape[dim + num_new_dims] = shape_[dim];
37+
strides[dim + num_new_dims] = strides_[dim];
38+
}
39+
40+
for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
41+
if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
42+
continue;
43+
}
44+
45+
shape[dim] = sizes[dim];
46+
strides[dim] = 0;
47+
}
48+
49+
return {data_, shape, strides};
50+
}
51+
52+
Tensor expand_as(const Tensor &other) const {
53+
return expand(other.shape_);
54+
}
55+
56+
private:
57+
const void *data_{nullptr};
58+
59+
std::vector<Size> shape_;
60+
61+
std::vector<Stride> strides_;
62+
63+
Size ndim_{0};
64+
65+
T value_{0};
66+
};
67+
68+
} // namespace ninetoothed

0 commit comments

Comments
 (0)