|
2 | 2 |
|
3 | 3 | #include "../../../../../build/ninetoothed/relu.h" |
4 | 4 | #include "../../../devices/nvidia/nvidia_common.cuh" |
| 5 | +#include "../../../ninetoothed/utils.h" |
5 | 6 | #include "relu_nvidia.cuh" |
6 | 7 |
|
7 | 8 | namespace op::relu::nvidia { |
@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( |
42 | 43 | } |
43 | 44 |
|
44 | 45 | const auto &ndim{_info.getNdim()}; |
45 | | - const auto &x_shape_{_info.getInputShape(0)}; |
46 | | - const auto &x_strides_{_info.getInputStrides(0)}; |
47 | | - std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; |
48 | | - std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim}; |
49 | | - auto x_data{const_cast<void *>(inputs[0])}; |
50 | | - auto x_shape{x_shape_vec.data()}; |
51 | | - auto x_strides{x_strides_vec.data()}; |
52 | | - const NineToothedTensor x{x_data, x_shape, x_strides}; |
53 | | - const auto &y_shape_{_info.getOutputShape()}; |
54 | | - const auto &y_strides_{_info.getOutputStrides()}; |
55 | | - std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim}; |
56 | | - std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim}; |
57 | | - auto y_data{output}; |
58 | | - auto y_shape{y_shape_vec.data()}; |
59 | | - auto y_strides{y_strides_vec.data()}; |
60 | | - const NineToothedTensor y{y_data, y_shape, y_strides}; |
| 46 | + |
| 47 | + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; |
| 48 | + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; |
| 49 | + |
61 | 50 | constexpr auto block_size{1024}; |
62 | 51 |
|
63 | 52 | switch (_dtype) { |
|
0 commit comments