400400namespace torch {
401401namespace executor {
402402
403- using Tensor = exec_aten::Tensor;
404- using Scalar = exec_aten::Scalar;
405- using ScalarType = exec_aten::ScalarType;
406-
407403//
408404// Utility functions for checking tensor attributes
409405//
@@ -432,7 +428,7 @@ inline bool dim_is_valid(int64_t dim, int64_t upper_bound) {
432428 * the zero dimensional tensors in some kernels, that treat them as 1D tensors
433429 * with a single element.
434430 */
435- inline ssize_t nonzero_dim (const Tensor& tensor) {
431+ inline ssize_t nonzero_dim (const exec_aten:: Tensor& tensor) {
436432 return tensor.dim () == 0 ? 1 : tensor.dim ();
437433}
438434
@@ -442,7 +438,7 @@ inline ssize_t nonzero_dim(const Tensor& tensor) {
442438 * the zero dimensional tensors in some kernels, that treat them as 1D tensors
443439 * with a single element.
444440 */
445- inline ssize_t nonempty_size (const Tensor& tensor, ssize_t dim) {
441+ inline ssize_t nonempty_size (const exec_aten:: Tensor& tensor, ssize_t dim) {
446442 return tensor.dim () == 0 ? 1 : tensor.size (dim);
447443}
448444
@@ -861,7 +857,7 @@ inline bool tensor_is_scalar(exec_aten::Tensor t) {
861857constexpr size_t kTensorDimensionLimit = 16 ;
862858
863859// / Returns the product of dim[0:dim), not including dim.
864- inline size_t getLeadingDims (const Tensor& tensor, int64_t dim) {
860+ inline size_t getLeadingDims (const exec_aten:: Tensor& tensor, int64_t dim) {
865861 ET_CHECK_MSG (
866862 dim >= 0 && dim <= tensor.dim (),
867863 " Ending dimension %" PRId64
@@ -876,7 +872,7 @@ inline size_t getLeadingDims(const Tensor& tensor, int64_t dim) {
876872}
877873
878874// / Returns the product of dim[dim+1:].
879- inline size_t getTrailingDims (const Tensor& tensor, int64_t dim) {
875+ inline size_t getTrailingDims (const exec_aten:: Tensor& tensor, int64_t dim) {
880876 ET_CHECK_MSG (
881877 dim >= -1 && dim < tensor.dim (),
882878 " Starting dimension %" PRId64
@@ -901,7 +897,7 @@ inline size_t getTrailingDims(const Tensor& tensor, int64_t dim) {
901897 * the tensor.
902898 */
903899inline size_t coordinateToIndex (
904- const Tensor& tensor,
900+ const exec_aten:: Tensor& tensor,
905901 const size_t * const coordinate) {
906902 size_t index = 0 ;
907903 for (int d = 0 ; d < tensor.dim (); ++d) {
@@ -921,8 +917,10 @@ inline size_t coordinateToIndex(
921917 * index. It is assumed that the array has kTensorDimensionLimit elements.
922918 * @returns void
923919 */
924- inline void
925- indexToCoordinate (const Tensor& tensor, size_t index, size_t * coordinate) {
920+ inline void indexToCoordinate (
921+ const exec_aten::Tensor& tensor,
922+ size_t index,
923+ size_t * coordinate) {
926924 ET_CHECK (index < tensor.numel ());
927925 for (auto i = 0 ; i < tensor.dim (); ++i) {
928926 auto dim = tensor.dim () - 1 - i;
@@ -947,12 +945,12 @@ template <
947945 typename std::enable_if<
948946 std::is_integral<INT_T>::value && !std::is_same<INT_T, bool >::value,
949947 bool >::type = true >
950- bool extract_scalar_tensor (Tensor tensor, INT_T* out_val) {
948+ bool extract_scalar_tensor (exec_aten:: Tensor tensor, INT_T* out_val) {
951949 if (tensor.numel () != 1 ) {
952950 return false ;
953951 }
954952#define CASE_INT_DTYPE (TENSOR_CTYPE, TENSOR_DTYPE ) \
955- case ScalarType::TENSOR_DTYPE: { \
953+ case exec_aten:: ScalarType::TENSOR_DTYPE: { \
956954 const TENSOR_CTYPE val = tensor.const_data_ptr <TENSOR_CTYPE>()[0 ]; \
957955 if (val < std::numeric_limits<INT_T>::lowest () || \
958956 val > std::numeric_limits<INT_T>::max ()) { \
@@ -984,12 +982,12 @@ template <
984982 typename FLOAT_T,
985983 typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool >::
986984 type = true >
987- bool extract_scalar_tensor (Tensor tensor, FLOAT_T* out_val) {
985+ bool extract_scalar_tensor (exec_aten:: Tensor tensor, FLOAT_T* out_val) {
988986 if (tensor.numel () != 1 ) {
989987 return false ;
990988 }
991989#define CASE_REAL_DTYPE (TENSOR_CTYPE, TENSOR_DTYPE ) \
992- case ScalarType::TENSOR_DTYPE: { \
990+ case exec_aten:: ScalarType::TENSOR_DTYPE: { \
993991 /* ET_FORALL_REAL_TYPES guarantees TENSOR_CTYPE is a real type. */ \
994992 double val = \
995993 static_cast <double >(tensor.const_data_ptr <TENSOR_CTYPE>()[0 ]); \
@@ -1022,7 +1020,7 @@ template <
10221020 typename BOOL_T,
10231021 typename std::enable_if<std::is_same<BOOL_T, bool >::value, bool >::type =
10241022 true >
1025- bool extract_scalar_tensor (Tensor tensor, BOOL_T* out_val) {
1023+ bool extract_scalar_tensor (exec_aten:: Tensor tensor, BOOL_T* out_val) {
10261024 if (tensor.scalar_type () != exec_aten::ScalarType::Bool) {
10271025 return false ;
10281026 }
0 commit comments