diff --git a/kernels/portable/cpu/util/advanced_index_util.cpp b/kernels/portable/cpu/util/advanced_index_util.cpp index e2eabec4bc7..3022ec61b93 100644 --- a/kernels/portable/cpu/util/advanced_index_util.cpp +++ b/kernels/portable/cpu/util/advanced_index_util.cpp @@ -7,6 +7,7 @@ */ #include +#include #include namespace torch { @@ -49,9 +50,22 @@ bool check_mask_indices(const Tensor& in, TensorOptList indices) { ET_LOG_MSG_AND_RETURN_IF_FALSE( index.dim() > 0, "Zero-dimensional mask index not allowed"); for (auto j = 0; j < index.dim(); j++) { - ET_LOG_MSG_AND_RETURN_IF_FALSE( - index.size(j) == in.size(in_i + j), - "The shape of mask index must match the sizes of the corresponding input dimensions."); + if (index.size(j) != in.size(in_i + j)) { +#ifdef ET_LOG_ENABLED + auto mask_shape = executorch::runtime::tensor_shape_to_c_string( + executorch::runtime::Span( + index.sizes().data(), index.sizes().size())); + auto input_shape = executorch::runtime::tensor_shape_to_c_string( + executorch::runtime::Span( + in.sizes().data() + in_i, index.sizes().size())); + ET_LOG( + Error, + "The shape of mask index %s must match the sizes of the corresponding input dimensions %s.", + mask_shape.data(), + input_shape.data()); +#endif // ET_LOG_ENABLED + return false; + } } in_i += index.dim(); } else { diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 0115feb6256..2c25d171568 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -117,6 +117,7 @@ def define_common_targets(): compiler_flags = ["-Wno-missing-prototypes"], deps = [ ":broadcast_util", + "//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],