/**
 * Determines whether a tensor can be interpreted as both channels_last and
 * contiguous memory formats without any issues in memory access.
 *
 * When certain dimensions are of size 1, the stride along those dimensions
 * doesn't impact the memory layout, making the tensor's data layout effectively
 * the same in both memory formats.
 *
 * Specifically, if the tensor's shape satisfies certain conditions (e.g., the
 * channel dimension C is 1, or both spatial dimensions H and W are 1), the
 * tensor can be safely interpreted under both memory formats without causing
 * inconsistencies in memory access.
 *
 * Note:
 * This is a temporary function because the current dim_order cannot explicitly
 * specify the correct memory format's dimension order. Once we resolve the
 * ambiguous dimension order issue, this check will be removed.
 *
 * @param[in] shape A pointer to an array representing the tensor's shape.
 * @param[in] dim The number of dimensions (length of the shape array).
 * @return True if the tensor can be interpreted as both formats without memory
 * access issues; False otherwise.
 */
template <typename SizesType>
bool can_be_interpreted_as_channels_last_and_contiguous(
    const SizesType* shape,
    const size_t dim) {
  // Only rank-4 (NCHW) tensors have a channels_last interpretation.
  if (dim != 4) {
    return false;
  }

  // shape = {N, C, H, W}. N's stride is the product of the remaining sizes in
  // both formats, so only C, H, and W determine whether the layouts coincide.
  const size_t C = static_cast<size_t>(shape[1]);
  const size_t H = static_cast<size_t>(shape[2]);
  const size_t W = static_cast<size_t>(shape[3]);

  // Condition 1: C == 1 — the channel stride is never used to advance through
  // memory, so contiguous (CHW, HW, W, 1) and channels_last (HWC, 1, WC, C)
  // strides agree on every dimension of size > 1.
  //
  // Condition 2: H == 1 && W == 1 — the spatial strides are never used, with
  // the same effect.
  //
  // Note: the original "Condition 3", ((H == 1 || W == 1) && C == 1), was dead
  // code — any shape satisfying it already returned true under Condition 1 —
  // and has been removed.
  return C == 1 || (H == 1 && W == 1);
}
|| + can_be_interpreted_as_channels_last_and_contiguous( + t.sizes().data(), t.dim()); if (!ret_val) { ET_LOG(Error, "Expected tensor to have channels last dim order, but got"); @@ -116,13 +120,18 @@ bool tensors_have_same_dim_order( bool all_channels_last = true; for (size_t i = 0; i < tensor_list.size(); ++i) { all_contiguous = all_contiguous && - is_contiguous_dim_order( - tensor_list[i].dim_order().data(), - tensor_list[i].dim_order().size()); + (is_contiguous_dim_order( + tensor_list[i].dim_order().data(), + tensor_list[i].dim_order().size()) || + can_be_interpreted_as_channels_last_and_contiguous( + tensor_list[i].sizes().data(), tensor_list[i].dim())); + all_channels_last = all_channels_last && - is_channels_last_dim_order( - tensor_list[i].dim_order().data(), - tensor_list[i].dim_order().size()); + (is_channels_last_dim_order( + tensor_list[i].dim_order().data(), + tensor_list[i].dim_order().size()) || + can_be_interpreted_as_channels_last_and_contiguous( + tensor_list[i].sizes().data(), tensor_list[i].dim())); } ET_LOG_MSG_AND_RETURN_IF_FALSE( diff --git a/runtime/core/exec_aten/util/test/dim_order_util_test.cpp b/runtime/core/exec_aten/util/test/dim_order_util_test.cpp index 6ce611c9266..58f73452356 100644 --- a/runtime/core/exec_aten/util/test/dim_order_util_test.cpp +++ b/runtime/core/exec_aten/util/test/dim_order_util_test.cpp @@ -19,6 +19,7 @@ using executorch::runtime::Error; using executorch::runtime::is_channels_last_dim_order; using executorch::runtime::is_contiguous_dim_order; using executorch::runtime::stride_to_dim_order; +using executorch::runtime::can_be_interpreted_as_channels_last_and_contiguous; namespace { void check_strides_eq( @@ -286,3 +287,43 @@ TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) { EXPECT_FALSE(is_channels_last_dim_order(dim_order_4d, 4)); EXPECT_FALSE(is_channels_last_dim_order(dim_order_5d, 5)); } + + +TEST(DimOrderUtilTest, MultiMemoryFormatInterpretSuccess) { + // Test that tensors having following sizes 
should be interruptable to both and contiguous memory format + exec_aten::SizesType sizes[][4] = { + {2, 1, 2, 2}, + {2, 2, 1, 1}, + {1, 2, 1, 1}, + {2, 1, 1, 2}, + {2, 1, 2, 1}, + {1, 1, 1, 2}, + {1, 1, 2, 2}, + {1, 1, 1, 1}, + {2, 1, 1, 1}, + {1, 1, 2, 1} + }; + + ssize_t n_testcases = sizeof(sizes) / sizeof(sizes[0]); + + for (size_t i = 0; i < n_testcases; i++) { + EXPECT_TRUE(can_be_interpreted_as_channels_last_and_contiguous(sizes[i], 4)); + } +} + +TEST(DimOrderUtilTest, MultiMemoryFormatInterpretFail) { + // Test that tensors having following sizes should only fall into single memory format. + exec_aten::SizesType sizes[][4] = { + {1, 2, 1, 2}, + {1, 2, 2, 1}, + {2, 2, 1, 2}, + {2, 2, 2, 2}, + {1, 2, 2, 2}, + {2, 2, 2, 1} + }; + + ssize_t n_testcases = sizeof(sizes) / sizeof(sizes[0]); + for (size_t i = 0; i < n_testcases; i++) { + EXPECT_FALSE(can_be_interpreted_as_channels_last_and_contiguous(sizes[i], 4)); + } +}