Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/infiniop/ninetoothed/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <vector>

namespace ninetoothed {

template <typename T = float>
class Tensor {
public:
using Data = decltype(NineToothedTensor::data);

using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;

using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;

template <typename Shape, typename Strides>
Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}

Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}

Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}

Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}

operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; }

template <typename Shape>
Tensor expand(const Shape &sizes) const {
auto new_ndim{sizes.size()};

decltype(shape_) shape(new_ndim, 1);
decltype(strides_) strides(new_ndim, 0);

auto num_new_dims{new_ndim - ndim_};

for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
shape[dim + num_new_dims] = shape_[dim];
strides[dim + num_new_dims] = strides_[dim];
}

for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
continue;
}

shape[dim] = sizes[dim];
strides[dim] = 0;
}

return {data_, shape, strides};
}

Tensor expand_as(const Tensor &other) const {
return expand(other.shape_);
}

private:
const void *data_{nullptr};

std::vector<Size> shape_;

std::vector<Stride> strides_;

Size ndim_{0};

T value_{0};
};

} // namespace ninetoothed
21 changes: 5 additions & 16 deletions src/infiniop/ops/relu/metax/relu_metax.maca
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../../../../../build/ninetoothed/relu.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../ninetoothed/utils.h"
#include "relu_metax.h"

namespace op::relu::metax {
Expand Down Expand Up @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
}

const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};

auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};

constexpr auto block_size{1024};

switch (_dtype) {
Expand Down
21 changes: 5 additions & 16 deletions src/infiniop/ops/relu/nvidia/relu_nvidia.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../../../../../build/ninetoothed/relu.h"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../ninetoothed/utils.h"
#include "relu_nvidia.cuh"

namespace op::relu::nvidia {
Expand Down Expand Up @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
}

const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};

auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};

constexpr auto block_size{1024};

switch (_dtype) {
Expand Down
Loading