|
| 1 | +#include <initializer_list> |
| 2 | +#include <limits> |
| 3 | +#include <type_traits> |
| 4 | +#include <vector> |
| 5 | + |
| 6 | +namespace ninetoothed { |
| 7 | + |
| 8 | +template <typename T = float> |
| 9 | +class Tensor { |
| 10 | +public: |
| 11 | + using Data = decltype(NineToothedTensor::data); |
| 12 | + |
| 13 | + using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>; |
| 14 | + |
| 15 | + using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>; |
| 16 | + |
| 17 | + template <typename Shape, typename Strides> |
| 18 | + Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {} |
| 19 | + |
| 20 | + Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} |
| 21 | + |
| 22 | + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} |
| 23 | + |
| 24 | + operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; } |
| 25 | + |
| 26 | + template <typename Shape> |
| 27 | + Tensor expand(const Shape &sizes) const { |
| 28 | + auto new_ndim{sizes.size()}; |
| 29 | + |
| 30 | + decltype(shape_) shape(new_ndim, 1); |
| 31 | + decltype(strides_) strides(new_ndim, 0); |
| 32 | + |
| 33 | + auto num_new_dims{new_ndim - ndim_}; |
| 34 | + |
| 35 | + for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) { |
| 36 | + shape[dim + num_new_dims] = shape_[dim]; |
| 37 | + strides[dim + num_new_dims] = strides_[dim]; |
| 38 | + } |
| 39 | + |
| 40 | + for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) { |
| 41 | + if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) { |
| 42 | + continue; |
| 43 | + } |
| 44 | + |
| 45 | + shape[dim] = sizes[dim]; |
| 46 | + strides[dim] = 0; |
| 47 | + } |
| 48 | + |
| 49 | + return {data_, shape, strides}; |
| 50 | + } |
| 51 | + |
| 52 | + Tensor expand_as(const Tensor &other) const { |
| 53 | + return expand(other.shape_); |
| 54 | + } |
| 55 | + |
| 56 | +private: |
| 57 | + const void *data_{nullptr}; |
| 58 | + |
| 59 | + std::vector<Size> shape_; |
| 60 | + |
| 61 | + std::vector<Stride> strides_; |
| 62 | + |
| 63 | + Size ndim_{0}; |
| 64 | + |
| 65 | + T value_{0}; |
| 66 | +}; |
| 67 | + |
| 68 | +} // namespace ninetoothed |
0 commit comments