Skip to content

Commit aee91b3

Browse files
committed
使用 ninetoothed::Tensor 接入九齿的 ReLU 算子
1 parent 9ea44db commit aee91b3

File tree

2 files changed: 10 additions (+10) and 32 deletions (-32).

src/infiniop/ops/relu/metax/relu_metax.maca

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
 2 |  2 |
 3 |  3 |   #include "../../../../../build/ninetoothed/relu.h"
 4 |  4 |   #include "../../../devices/metax/metax_common.h"
   |  5 | + #include "../../../ninetoothed/utils.h"
 5 |  6 |   #include "relu_metax.h"
 6 |  7 |
 7 |  8 |   namespace op::relu::metax {
@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
42 | 43 |   }
43 | 44 |
44 | 45 |   const auto &ndim{_info.getNdim()};
45 |    | - const auto &x_shape_{_info.getInputShape(0)};
46 |    | - const auto &x_strides_{_info.getInputStrides(0)};
47 |    | - std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
48 |    | - std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
49 |    | - auto x_data{const_cast<void *>(inputs[0])};
50 |    | - auto x_shape{x_shape_vec.data()};
51 |    | - auto x_strides{x_strides_vec.data()};
52 |    | - const NineToothedTensor x{x_data, x_shape, x_strides};
53 |    | - const auto &y_shape_{_info.getOutputShape()};
54 |    | - const auto &y_strides_{_info.getOutputStrides()};
55 |    | - std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
56 |    | - std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
57 |    | - auto y_data{output};
58 |    | - auto y_shape{y_shape_vec.data()};
59 |    | - auto y_strides{y_strides_vec.data()};
60 |    | - const NineToothedTensor y{y_data, y_shape, y_strides};
   | 46 | +
   | 47 | + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
   | 48 | + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
   | 49 | +
61 | 50 |   constexpr auto block_size{1024};
62 | 51 |
63 | 52 |   switch (_dtype) {

src/infiniop/ops/relu/nvidia/relu_nvidia.cu

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
 2 |  2 |
 3 |  3 |   #include "../../../../../build/ninetoothed/relu.h"
 4 |  4 |   #include "../../../devices/nvidia/nvidia_common.cuh"
   |  5 | + #include "../../../ninetoothed/utils.h"
 5 |  6 |   #include "relu_nvidia.cuh"
 6 |  7 |
 7 |  8 |   namespace op::relu::nvidia {
@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
42 | 43 |   }
43 | 44 |
44 | 45 |   const auto &ndim{_info.getNdim()};
45 |    | - const auto &x_shape_{_info.getInputShape(0)};
46 |    | - const auto &x_strides_{_info.getInputStrides(0)};
47 |    | - std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
48 |    | - std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
49 |    | - auto x_data{const_cast<void *>(inputs[0])};
50 |    | - auto x_shape{x_shape_vec.data()};
51 |    | - auto x_strides{x_strides_vec.data()};
52 |    | - const NineToothedTensor x{x_data, x_shape, x_strides};
53 |    | - const auto &y_shape_{_info.getOutputShape()};
54 |    | - const auto &y_strides_{_info.getOutputStrides()};
55 |    | - std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
56 |    | - std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
57 |    | - auto y_data{output};
58 |    | - auto y_shape{y_shape_vec.data()};
59 |    | - auto y_strides{y_strides_vec.data()};
60 |    | - const NineToothedTensor y{y_data, y_shape, y_strides};
   | 46 | +
   | 47 | + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
   | 48 | + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
   | 49 | +
61 | 50 |   constexpr auto block_size{1024};
62 | 51 |
63 | 52 |   switch (_dtype) {

0 commit comments

Comments
 (0)