Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions runtime/core/exec_aten/util/dim_order_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,63 @@ bool is_channels_last_dim_order(
return true;
}

/**
 * Determines whether a tensor can be interpreted as both channels_last and
 * contiguous memory formats without any issues in memory access.
 *
 * When a dimension has size 1, the stride along that dimension never
 * participates in address computation (the index along it is always 0), so
 * the tensor's underlying data layout is effectively identical in both
 * memory formats.
 *
 * For a 4-D tensor (N, C, H, W) the two layouts coincide exactly when:
 *   - the channel dimension C is 1, or
 *   - both spatial dimensions H and W are 1.
 *
 * Note:
 * This is a temporary function because the current dim_order cannot explicitly
 * specify the correct memory format's dimension order. Once we resolve the
 * ambiguous dimension order issue, this check will be removed.
 *
 * @param[in] shape A pointer to an array representing the tensor's shape.
 * @param[in] dim The number of dimensions (length of the shape array).
 * @return True if the tensor can be interpreted as both formats without memory
 * access issues; False otherwise.
 */
template <typename SizesType>
bool can_be_interpreted_as_channels_last_and_contiguous(
    const SizesType* shape,
    const size_t dim) {
  // Only the 4-dimensional (NCHW vs. NHWC) case is ambiguous; any other
  // rank is not considered interchangeable.
  if (dim != 4) {
    return false;
  }

  // Extract C (channels), H (height), W (width). The batch size N (shape[0])
  // is the outermost dimension in both formats, so it never affects the
  // ambiguity check. Cast explicitly since SizesType may be signed.
  const size_t C = static_cast<size_t>(shape[1]);
  const size_t H = static_cast<size_t>(shape[2]);
  const size_t W = static_cast<size_t>(shape[3]);

  // The layouts coincide when the channel dimension is trivial (C == 1) or
  // when both spatial dimensions are trivial (H == 1 && W == 1).
  // The former "Condition 3" ((H == 1 || W == 1) && C == 1) was unreachable
  // dead code: it is strictly implied by C == 1, so it has been removed.
  return C == 1 || (H == 1 && W == 1);
}

/*
* This utility translated sizes to strides by using dimension order
* information. Dimension order specifies how the dimensions are laid out in the
Expand Down
25 changes: 17 additions & 8 deletions runtime/core/exec_aten/util/tensor_util_portable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ bool tensor_is_default_or_channels_last_dim_order(torch::executor::Tensor t) {

bool tensor_is_default_dim_order(torch::executor::Tensor t) {
bool ret_val =
is_contiguous_dim_order(t.dim_order().data(), t.dim_order().size());
is_contiguous_dim_order(t.dim_order().data(), t.dim_order().size()) ||
can_be_interpreted_as_channels_last_and_contiguous(
t.sizes().data(), t.dim());

if (!ret_val) {
ET_LOG(Error, "Expected tensor to have default dim order, but got");
Expand All @@ -92,7 +94,9 @@ bool tensor_is_default_dim_order(torch::executor::Tensor t) {

bool tensor_is_channels_last_dim_order(torch::executor::Tensor t) {
bool ret_val =
is_channels_last_dim_order(t.dim_order().data(), t.dim_order().size());
is_channels_last_dim_order(t.dim_order().data(), t.dim_order().size()) ||
can_be_interpreted_as_channels_last_and_contiguous(
t.sizes().data(), t.dim());

if (!ret_val) {
ET_LOG(Error, "Expected tensor to have channels last dim order, but got");
Expand All @@ -116,13 +120,18 @@ bool tensors_have_same_dim_order(
bool all_channels_last = true;
for (size_t i = 0; i < tensor_list.size(); ++i) {
all_contiguous = all_contiguous &&
is_contiguous_dim_order(
tensor_list[i].dim_order().data(),
tensor_list[i].dim_order().size());
(is_contiguous_dim_order(
tensor_list[i].dim_order().data(),
tensor_list[i].dim_order().size()) ||
can_be_interpreted_as_channels_last_and_contiguous(
tensor_list[i].sizes().data(), tensor_list[i].dim()));

all_channels_last = all_channels_last &&
is_channels_last_dim_order(
tensor_list[i].dim_order().data(),
tensor_list[i].dim_order().size());
(is_channels_last_dim_order(
tensor_list[i].dim_order().data(),
tensor_list[i].dim_order().size()) ||
can_be_interpreted_as_channels_last_and_contiguous(
tensor_list[i].sizes().data(), tensor_list[i].dim()));
}

ET_LOG_MSG_AND_RETURN_IF_FALSE(
Expand Down
41 changes: 41 additions & 0 deletions runtime/core/exec_aten/util/test/dim_order_util_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using executorch::runtime::Error;
using executorch::runtime::is_channels_last_dim_order;
using executorch::runtime::is_contiguous_dim_order;
using executorch::runtime::stride_to_dim_order;
using executorch::runtime::can_be_interpreted_as_channels_last_and_contiguous;

namespace {
void check_strides_eq(
Expand Down Expand Up @@ -286,3 +287,43 @@ TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
EXPECT_FALSE(is_channels_last_dim_order(dim_order_4d, 4));
EXPECT_FALSE(is_channels_last_dim_order(dim_order_5d, 5));
}


TEST(DimOrderUtilTest, MultiMemoryFormatInterpretSuccess) {
  // Tensors with these sizes should be interpretable as both channels_last
  // and contiguous memory formats: in every case either C == 1 or
  // H == W == 1, so the strides of the size-1 dimensions cannot affect
  // memory access.
  exec_aten::SizesType sizes[][4] = {
      {2, 1, 2, 2},
      {2, 2, 1, 1},
      {1, 2, 1, 1},
      {2, 1, 1, 2},
      {2, 1, 2, 1},
      {1, 1, 1, 2},
      {1, 1, 2, 2},
      {1, 1, 1, 1},
      {2, 1, 1, 1},
      {1, 1, 2, 1}};

  // size_t (not ssize_t) so the comparison with the loop index is not a
  // signed/unsigned mismatch.
  const size_t n_testcases = sizeof(sizes) / sizeof(sizes[0]);

  for (size_t i = 0; i < n_testcases; i++) {
    EXPECT_TRUE(
        can_be_interpreted_as_channels_last_and_contiguous(sizes[i], 4));
  }
}

TEST(DimOrderUtilTest, MultiMemoryFormatInterpretFail) {
  // Tensors with these sizes fall into only a single memory format: in every
  // case C > 1 and at least one spatial dimension is > 1, so channels_last
  // and contiguous layouts produce different memory access patterns.
  exec_aten::SizesType sizes[][4] = {
      {1, 2, 1, 2},
      {1, 2, 2, 1},
      {2, 2, 1, 2},
      {2, 2, 2, 2},
      {1, 2, 2, 2},
      {2, 2, 2, 1}};

  // size_t (not ssize_t) so the comparison with the loop index is not a
  // signed/unsigned mismatch.
  const size_t n_testcases = sizeof(sizes) / sizeof(sizes[0]);

  for (size_t i = 0; i < n_testcases; i++) {
    EXPECT_FALSE(
        can_be_interpreted_as_channels_last_and_contiguous(sizes[i], 4));
  }
}
Loading