From a524007198cced41415bec1eb71112df8dec849a Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 25 Aug 2025 20:04:27 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=AF=B9=20`NineToothedTensor`=20=E8=BF=9B?= =?UTF-8?q?=E8=A1=8C=20C++=20=E5=B1=82=E5=B0=81=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ninetoothed/utils.h | 68 ++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 src/infiniop/ninetoothed/utils.h diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h new file mode 100644 index 000000000..f3b18fd29 --- /dev/null +++ b/src/infiniop/ninetoothed/utils.h @@ -0,0 +1,68 @@ +#include +#include +#include +#include + +namespace ninetoothed { + +template +class Tensor { +public: + using Data = decltype(NineToothedTensor::data); + + using Size = std::remove_pointer_t; + + using Stride = std::remove_pointer_t; + + template + Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {} + + Tensor(const void *data, std::initializer_list shape, std::initializer_list strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} + + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} + + operator NineToothedTensor() { return {const_cast(data_), shape_.data(), strides_.data()}; } + + template + Tensor expand(const Shape &sizes) const { + auto new_ndim{sizes.size()}; + + decltype(shape_) shape(new_ndim, 1); + decltype(strides_) strides(new_ndim, 0); + + auto num_new_dims{new_ndim - ndim_}; + + for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) { + shape[dim + num_new_dims] = shape_[dim]; + strides[dim + num_new_dims] = strides_[dim]; + } + + for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) { + if (sizes[dim] == std::numeric_limits>::max() || shape[dim] != 1) { + continue; + } + + shape[dim] = sizes[dim]; + strides[dim] = 0; + } + + return {data_, shape, strides}; + } + + Tensor expand_as(const Tensor &other) const { + return expand(other.shape_); + } + +private: + const void *data_{nullptr}; + + std::vector shape_; + + std::vector strides_; + + Size ndim_{0}; + + T value_{0}; +}; + +} // namespace ninetoothed From 9ea44db5e0c1e949038c7618cee23180a19dde10 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 25 Aug 2025 20:39:47 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E6=95=B0=E7=BB=84=E4=BD=9C=E4=B8=BA=20`shape`=20=E5=92=8C=20`s?= =?UTF-8?q?trides`=20=E5=88=9B=E5=BB=BA=20`ninetoothed::Tensor`=20?= =?UTF-8?q?=E7=9A=84=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ninetoothed/utils.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h index f3b18fd29..d6a0c012b 100644 --- a/src/infiniop/ninetoothed/utils.h +++ b/src/infiniop/ninetoothed/utils.h @@ -19,6 +19,8 @@ class Tensor { Tensor(const void *data, std::initializer_list shape, std::initializer_list strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} + Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {} + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} operator NineToothedTensor() { return {const_cast(data_), shape_.data(), strides_.data()}; } From aee91b35360bf59c76af9f8ddbbc16f54793537b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 25 Aug 2025 20:41:59 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20`ninetoothed::Tensor`?= =?UTF-8?q?=20=E6=8E=A5=E5=85=A5=E4=B9=9D=E9=BD=BF=E7=9A=84=20ReLU=20?= =?UTF-8?q?=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/relu/metax/relu_metax.maca | 21 +++++---------------- src/infiniop/ops/relu/nvidia/relu_nvidia.cu | 21 +++++---------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/src/infiniop/ops/relu/metax/relu_metax.maca b/src/infiniop/ops/relu/metax/relu_metax.maca index 900fce9e0..2c5104bdd 100644 --- a/src/infiniop/ops/relu/metax/relu_metax.maca +++ b/src/infiniop/ops/relu/metax/relu_metax.maca @@ -2,6 +2,7 @@ #include "../../../../../build/ninetoothed/relu.h" #include "../../../devices/metax/metax_common.h" +#include "../../../ninetoothed/utils.h" #include "relu_metax.h" namespace op::relu::metax { @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( } const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { diff --git a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu index 5e9151081..417c3371a 100644 --- a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu +++ b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu @@ -2,6 +2,7 @@ #include "../../../../../build/ninetoothed/relu.h" #include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../ninetoothed/utils.h" #include "relu_nvidia.cuh" namespace op::relu::nvidia { @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( } const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) {