|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | +#include <gtest/gtest.h> |
| 9 | + |
| 10 | +#include <executorch/kernels/portable/cpu/util/tensor_util.h> |
| 11 | +#include <executorch/runtime/core/exec_aten/exec_aten.h> |
| 12 | +#include <executorch/runtime/platform/runtime.h> |
| 13 | +#include <array> |
| 14 | + |
| 15 | +using executorch::runtime::kTensorDimensionLimit; |
| 16 | +using executorch::runtime::Span; |
| 17 | +using executorch::runtime::tensor_shape_to_c_string; |
| 18 | +using executorch::runtime::internal::kMaximumPrintableTensorShapeElement; |
| 19 | + |
| 20 | +TEST(TensorUtilTest, TensorShapeToCStringBasic) { |
| 21 | + std::array<executorch::aten::SizesType, 3> sizes = {123, 456, 789}; |
| 22 | + auto str = tensor_shape_to_c_string( |
| 23 | + Span<const executorch::aten::SizesType>(sizes.data(), sizes.size())); |
| 24 | + EXPECT_STREQ(str.data(), "(123, 456, 789)"); |
| 25 | + |
| 26 | + std::array<executorch::aten::SizesType, 1> one_size = {1234567890}; |
| 27 | + str = tensor_shape_to_c_string(Span<const executorch::aten::SizesType>( |
| 28 | + one_size.data(), one_size.size())); |
| 29 | + EXPECT_STREQ(str.data(), "(1234567890)"); |
| 30 | +} |
| 31 | + |
| 32 | +TEST(TensorUtilTest, TensorShapeToCStringNegativeItems) { |
| 33 | + std::array<executorch::aten::SizesType, 4> sizes = {-1, -3, -2, 4}; |
| 34 | + auto str = tensor_shape_to_c_string( |
| 35 | + Span<const executorch::aten::SizesType>(sizes.data(), sizes.size())); |
| 36 | + EXPECT_STREQ(str.data(), "(ERR, ERR, ERR, 4)"); |
| 37 | + |
| 38 | + std::array<executorch::aten::SizesType, 1> one_size = {-1234567890}; |
| 39 | + str = tensor_shape_to_c_string(Span<const executorch::aten::SizesType>( |
| 40 | + one_size.data(), one_size.size())); |
| 41 | + if constexpr (std::numeric_limits<executorch::aten::SizesType>::is_signed) { |
| 42 | + EXPECT_STREQ(str.data(), "(ERR)"); |
| 43 | + } else { |
| 44 | + EXPECT_EQ(str.data(), "(" + std::to_string(one_size[0]) + ")"); |
| 45 | + } |
| 46 | +} |
| 47 | +TEST(TensorUtilTest, TensorShapeToCStringMaximumElement) { |
| 48 | + std::array<executorch::aten::SizesType, 3> sizes = { |
| 49 | + 123, std::numeric_limits<executorch::aten::SizesType>::max(), 789}; |
| 50 | + auto str = tensor_shape_to_c_string( |
| 51 | + Span<const executorch::aten::SizesType>(sizes.data(), sizes.size())); |
| 52 | + std::ostringstream expected; |
| 53 | + expected << '('; |
| 54 | + for (const auto elem : sizes) { |
| 55 | + expected << elem << ", "; |
| 56 | + } |
| 57 | + auto expected_str = expected.str(); |
| 58 | + expected_str.pop_back(); |
| 59 | + expected_str.back() = ')'; |
| 60 | + EXPECT_EQ(str.data(), expected_str); |
| 61 | +} |
| 62 | + |
| 63 | +TEST(TensorUtilTest, TensorShapeToCStringMaximumLength) { |
| 64 | + std::array<executorch::aten::SizesType, kTensorDimensionLimit> sizes; |
| 65 | + std::fill(sizes.begin(), sizes.end(), kMaximumPrintableTensorShapeElement); |
| 66 | + |
| 67 | + auto str = tensor_shape_to_c_string( |
| 68 | + Span<const executorch::aten::SizesType>(sizes.data(), sizes.size())); |
| 69 | + |
| 70 | + std::ostringstream expected; |
| 71 | + expected << '(' << kMaximumPrintableTensorShapeElement; |
| 72 | + for (int ii = 0; ii < kTensorDimensionLimit - 1; ++ii) { |
| 73 | + expected << ", " << kMaximumPrintableTensorShapeElement; |
| 74 | + } |
| 75 | + expected << ')'; |
| 76 | + auto expected_str = expected.str(); |
| 77 | + |
| 78 | + EXPECT_EQ(expected_str, str.data()); |
| 79 | +} |
| 80 | + |
| 81 | +TEST(TensorUtilTest, TensorShapeToCStringExceedsDimensionLimit) { |
| 82 | + std::array<executorch::aten::SizesType, kTensorDimensionLimit + 1> sizes; |
| 83 | + std::fill(sizes.begin(), sizes.end(), kMaximumPrintableTensorShapeElement); |
| 84 | + |
| 85 | + auto str = tensor_shape_to_c_string( |
| 86 | + Span<const executorch::aten::SizesType>(sizes.data(), sizes.size())); |
| 87 | + |
| 88 | + EXPECT_STREQ(str.data(), "(ERR: tensor ndim exceeds limit)"); |
| 89 | +} |
0 commit comments