From 9e82b267dbba5d8722f1689b2a6590ddd8aa0b4a Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Tue, 10 Jun 2025 20:38:22 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../portable/cpu/op__to_dim_order_copy.cpp    | 48 +++++--------------
 kernels/portable/cpu/op_glu.cpp               |  2 +-
 .../cpu/util/broadcast_indexes_range.h        | 25 ++++++----
 kernels/portable/cpu/util/elementwise_util.h  | 10 ++--
 .../kernels/portable/op_registration_util.bzl |  1 +
 5 files changed, 34 insertions(+), 52 deletions(-)

diff --git a/kernels/portable/cpu/op__to_dim_order_copy.cpp b/kernels/portable/cpu/op__to_dim_order_copy.cpp
index 70fc3507f05..a1abf8a6a00 100644
--- a/kernels/portable/cpu/op__to_dim_order_copy.cpp
+++ b/kernels/portable/cpu/op__to_dim_order_copy.cpp
@@ -9,8 +9,8 @@
 
 #include
 #include
+#include
 #include
-#include
 #include
 
 namespace torch {
@@ -31,47 +31,21 @@ using Optional = executorch::aten::optional;
 
 namespace {
 
-// TODO(T179241236): Update core/exec_aten/util/tensor_util.h to support dim
-// order other than contiguous.
-int64_t coordinateToIndexWithDimOrder(
-    const Tensor& self,
-    const size_t* cur_indices) {
-  int64_t index = 0;
-  executorch::aten::StridesType strides[kTensorDimensionLimit];
-  SizesArrayRef sizes = self.sizes();
-  DimOrderArrayRef dim_order = self.dim_order();
-
-  dim_order_to_stride_nocheck(
-      sizes.data(), dim_order.data(), sizes.size(), strides);
-  for (const auto i : c10::irange(self.dim())) {
-    index += cur_indices[i] * strides[i];
-  }
-  return index;
-}
-
 template <typename SELF_CTYPE, typename OUT_CTYPE>
 void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
   auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
   auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
 
-  size_t coordinate[kTensorDimensionLimit] = {0};
-
-  // Copy data from self to out index by index. Same index in self and out
-  // should have same value, no matter the order of dimensions.
-  for (ssize_t i = 0; i < self.numel(); i++) {
-    // Update the current indices.
-    for (ssize_t j = self.dim() - 1; j >= 0; j--) {
-      if (coordinate[j] + 1 < static_cast<size_t>(self.size(j))) {
-        coordinate[j]++;
-        break;
-      } else {
-        coordinate[j] = 0;
-      }
-    }
-    // Get the corresponding index of self_data and out_data by stride.
-    int64_t self_data_index = coordinateToIndexWithDimOrder(self, coordinate);
-    int64_t out_data_index = coordinateToIndexWithDimOrder(out, coordinate);
-
+  // Here we make a slightly off-label use of
+  // BroadcastIndexesRange. It always assumes it doesn't have to care
+  // about different dim_order between input and output, but we can
+  // just force it to respect strides (and thus dim_order) for its
+  // inputs using support_noncontiguous_input_tensors=true, and then pretend
+  // the output is just another input.
+  for (const auto [unused_index, self_data_index, out_data_index] :
+       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
+           /*dummy output*/ self, self, out)) {
+    (void)unused_index;
     out_data[out_data_index] =
         static_cast<OUT_CTYPE>(self_data[self_data_index]);
   }
diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp
index 3ac04f087ab..f204b0fd516 100644
--- a/kernels/portable/cpu/op_glu.cpp
+++ b/kernels/portable/cpu/op_glu.cpp
@@ -110,7 +110,7 @@ Tensor& glu_out_tensor(
         split_input.second_half,
         utils::SupportedTensorDtypes::FLOATHBF16,
         out,
-        utils::internal::SupportNoncontiguousTensors());
+        utils::internal::SupportNoncontiguousInputTensors());
   });
   return out;
 }
diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index d372767819a..7434748d505 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -43,7 +43,9 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <
+    std::size_t kNumInputs,
+    bool support_noncontiguous_input_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,7 +59,7 @@ class BroadcastIndexesIterator {
 
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            !support_noncontiguous_tensors &&
+            !support_noncontiguous_input_tensors &&
             (sizes_match_ignoring_leading_1s(
                  args.sizes(), output.sizes()) &&
@@ -69,7 +71,7 @@ class BroadcastIndexesIterator {
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (support_noncontiguous_tensors ||
+    if (support_noncontiguous_input_tensors ||
         output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
@@ -254,16 +256,21 @@ class BroadcastIndexesIterator {
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
  *
- * The support_noncontiguous_tensors argument disables an optimization
- * that causes the iterators not to respect strides in some
- * cases. This optimization is normally safe because ExecuTorch
- * tensors are contiguous.
+ * The support_noncontiguous_input_tensors argument disables an
+ * optimization that causes the iterators not to respect strides in
+ * some cases for input tensors. This optimization is normally safe
+ * because ExecuTorch tensors are contiguous. Non-contiguous output
+ * tensors are currently never supported (but note that this can be
+ * worked around by ignoring the output index and providing the true
+ * output as an extra input).
  */
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <
+    std::size_t kNumInputs,
+    bool support_noncontiguous_input_tensors = false>
 class BroadcastIndexesRange {
  public:
   using iterator = internal::
-      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_input_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 722483ec363..d07250f1d66 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -56,8 +56,8 @@ namespace internal {
  * strides; normally, this is not strictly necessary because ExecuTorch
  * Tensors are contiguous.
  */
-struct SupportNoncontiguousTensors {
-  explicit SupportNoncontiguousTensors() = default;
+struct SupportNoncontiguousInputTensors {
+  explicit SupportNoncontiguousInputTensors() = default;
 };
 
 template <
@@ -292,7 +292,7 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
       op_name,
@@ -366,7 +366,7 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
       op_name,
@@ -467,7 +467,7 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
       op_name,
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index 96941590dd4..1ae20ca7c61 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1324,6 +1324,7 @@ ATEN_OPS = (
         name = "op__to_dim_order_copy",
         deps = [
             ":scalar_utils",
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
            "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
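
Note (not part of the patch): the op__to_dim_order_copy change leans on the dim_order -> strides -> flat-index mapping that the deleted coordinateToIndexWithDimOrder() helper computed by hand and that BroadcastIndexesRange now provides when support_noncontiguous_input_tensors is set. The self-contained C++ sketch below uses illustrative names only (strides_from_dim_order and coord_to_index are not ExecuTorch APIs) to show that mapping for a 2x3 tensor copied from a contiguous layout into a dim_order {1, 0} layout.

// Hypothetical standalone sketch; mirrors dim_order_to_stride_nocheck() and the
// removed per-coordinate indexing only in spirit, not in API.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace {

// Compute strides from sizes and a dim_order listed outermost-to-innermost.
std::vector<int64_t> strides_from_dim_order(
    const std::vector<int64_t>& sizes,
    const std::vector<size_t>& dim_order) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (size_t i = dim_order.size(); i-- > 0;) {
    const size_t d = dim_order[i];
    strides[d] = running;
    running *= sizes[d];
  }
  return strides;
}

// Logical coordinate -> linear index into the underlying buffer.
int64_t coord_to_index(
    const std::vector<size_t>& coord,
    const std::vector<int64_t>& strides) {
  int64_t index = 0;
  for (size_t d = 0; d < coord.size(); ++d) {
    index += static_cast<int64_t>(coord[d]) * strides[d];
  }
  return index;
}

} // namespace

int main() {
  const std::vector<int64_t> sizes = {2, 3};
  // self is contiguous (dim_order {0, 1}); out uses dim_order {1, 0}.
  const auto self_strides = strides_from_dim_order(sizes, {0, 1}); // {3, 1}
  const auto out_strides = strides_from_dim_order(sizes, {1, 0}); // {1, 2}

  const std::vector<float> self_data = {0, 1, 2, 3, 4, 5}; // row-major
  std::vector<float> out_data(6, -1.f);

  // Same logical coordinate, different buffer index on each side: exactly the
  // mapping the kernel needs when input and output dim_order differ.
  for (size_t i = 0; i < 2; ++i) {
    for (size_t j = 0; j < 3; ++j) {
      const std::vector<size_t> coord = {i, j};
      out_data[coord_to_index(coord, out_strides)] =
          self_data[coord_to_index(coord, self_strides)];
    }
  }

  // Element (1, 2) has value 5; in the {1, 0} layout it lands at 2*2 + 1*1 = 5.
  assert(out_data[coord_to_index({1, 2}, out_strides)] == 5.f);
  return 0;
}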
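Note (not part of the patch): the renamed tag type keeps its explicit defaulted constructor, which is what forces call sites such as op_glu.cpp to spell out utils::internal::SupportNoncontiguousInputTensors() rather than passing a bare {}. Below is a minimal standalone sketch of that tag-dispatch pattern with hypothetical apply_fn overloads; only the struct's shape comes from the patch.

// Hypothetical sketch of opt-in tag dispatch with an explicit default ctor.
#include <iostream>

struct SupportNoncontiguousInputTensors {
  explicit SupportNoncontiguousInputTensors() = default;
};

// Fast path: assumes contiguous inputs.
void apply_fn(int x) {
  std::cout << "contiguous path: " << x << "\n";
}

// Opt-in path: respects input strides.
void apply_fn(int x, SupportNoncontiguousInputTensors) {
  std::cout << "strided-input path: " << x << "\n";
}

int main() {
  apply_fn(1);                                     // fast path
  apply_fn(2, SupportNoncontiguousInputTensors()); // explicit opt-in
  // apply_fn(3, {}); // does not compile: the default constructor is explicit,
  //                  // so the strided path can never be selected by accident.
  return 0;
}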