From 0f3920565267fabc6e169a86f39e572a68faf1d3 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 9 Jun 2025 17:17:33 -0700 Subject: [PATCH 1/3] Reapply #11294 and #11295 (improve GLU test and implement using internal views to avoid copying) These were reverted due to internal test failures. Sending this as an exported internal diff so that we can make sure we get internal signal. Original summary for #11294 (to make the GLU test input asymmetric): This way it will produce different results along each tested dim. Original summaryfor #11295: GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the ATen that uses views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35) entirely internally to the op. To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional SupportNonContiguousTensors parameter. Differential Revision: [D76311585](https://our.internmc.facebook.com/intern/diff/D76311585/) [ghstack-poisoned] --- kernels/portable/cpu/op_glu.cpp | 150 +++++++----------- .../cpu/util/broadcast_indexes_range.h | 22 ++- kernels/portable/cpu/util/elementwise_util.h | 127 +++++++++++++-- kernels/test/op_glu_test.cpp | 15 +- .../kernels/portable/op_registration_util.bzl | 1 + 5 files changed, 201 insertions(+), 114 deletions(-) diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp index edc82c55eb8..be76a158182 100644 --- a/kernels/portable/cpu/op_glu.cpp +++ b/kernels/portable/cpu/op_glu.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType; namespace { -double exp_overload(double d) { - return exp(d); -} - -float exp_overload(float f) { - return expf(f); -} - -/** - * In-place element-wise sigmoid function , i.e., f(x) = 1 / (1 + e^{-x}) - */ -// TODO: T146333648, refactor this as a common helper function -template -void sigmoid_tensor(Tensor& out) { - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i])); - } -} - -/** - * Element-wise multiplication of the first half of `in` along the specified - * dimension and `out`, overwriting `out`. - */ -template -void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) { - size_t num_values = static_cast(in.size(dim)) / 2; - size_t dim_length_in = static_cast(in.size(dim)); - size_t dim_length_out = static_cast(out.size(dim)); - size_t leading_dims = getLeadingDims(in, dim); - size_t trailing_dims = getTrailingDims(in, dim); - - const CTYPE_IN* input_data_base = in.const_data_ptr(); - CTYPE_OUT* output_data_base = out.mutable_data_ptr(); - - for (const auto i : c10::irange(leading_dims)) { - const CTYPE_IN* input_data = - input_data_base + i * dim_length_in * trailing_dims; - CTYPE_OUT* output_data = - output_data_base + i * dim_length_out * trailing_dims; - for ([[maybe_unused]] const auto j : c10::irange(num_values)) { - for (const auto k : c10::irange(trailing_dims)) { - output_data[k] = static_cast(input_data[k]) * output_data[k]; - } - input_data += trailing_dims; - output_data += trailing_dims; - } - } -} - -/** - * Slice the tensor in the given dim, from start to end, assume tensor in and - * out have same shape and dtype, the dim is a non-negative number and start, - * end are valid non-negative number - */ -template -void slice_tensor( - const Tensor& in, - int64_t dim, - int64_t start, - int64_t end, - Tensor& out) { - size_t num_values = static_cast(end - start); - size_t dim_length_in = static_cast(in.size(dim)); - size_t dim_length_out = static_cast(out.size(dim)); - size_t non_negative_start = static_cast(start); - size_t leading_dims = getLeadingDims(in, dim); - size_t trailing_dims = getTrailingDims(in, dim); - - const CTYPE_IN* input_data_base = in.const_data_ptr(); - CTYPE_OUT* output_data_base = out.mutable_data_ptr(); - - for (const auto i : c10::irange(leading_dims)) { - const CTYPE_IN* input_data = input_data_base + - (i * dim_length_in + non_negative_start) * trailing_dims; - CTYPE_OUT* output_data = - output_data_base + i * dim_length_out * trailing_dims; - for ([[maybe_unused]] const auto j : c10::irange(num_values)) { - for (const auto k : c10::irange(trailing_dims)) { - output_data[k] = static_cast(input_data[k]); - } - input_data += trailing_dims; - output_data += trailing_dims; - } - } -} - /** * Applies the gated linear unit function * @@ -120,11 +34,63 @@ void slice_tensor( * 2. The output shall be in float types (Float, Double) */ template -Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) { +Tensor& glu_out_tensor( + KernelRuntimeContext& ctx, + const Tensor& self, + int64_t dim, + Tensor& out) { const auto self_size = self.size(dim); - slice_tensor(self, dim, self_size / 2, self_size, out); - sigmoid_tensor(out); - mul_tensors(self, dim, out); + ET_KERNEL_CHECK( + ctx, + self.dim() <= static_cast(kTensorDimensionLimit), + InvalidArgument, + out); + std::array half_sizes; + std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin()); + half_sizes[dim] /= 2; + TensorImpl first_half_impl( + self.scalar_type(), + self.dim(), + half_sizes.data(), + self.mutable_data_ptr(), + const_cast(self.dim_order().data()), + const_cast(self.strides().data()), + self.shape_dynamism()); + TensorImpl second_half_impl( + self.scalar_type(), + self.dim(), + half_sizes.data(), + reinterpret_cast(self.mutable_data_ptr()) + + self.strides()[dim] * self_size / 2 * self.element_size(), + const_cast(self.dim_order().data()), + const_cast(self.strides().data()), + self.shape_dynamism()); + Tensor first_half(&first_half_impl); + Tensor second_half(&second_half_impl); + ScalarType compute_type = + executorch::runtime::isFloatingType(self.scalar_type()) + ? self.scalar_type() + : ScalarType::Float; + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "glu.out"; + ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::FLOATHBF16>( + [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE { + // TODO: rewrite this to be vectorization-capable. + const auto one = static_cast(1.0); + return val_a * (one / (one + std::exp(-val_b))); + }, + ctx, + first_half, + utils::SupportedTensorDtypes::FLOATHBF16, + second_half, + utils::SupportedTensorDtypes::FLOATHBF16, + out, + utils::internal::SupportNoncontiguousTensors()); + }); return out; } } // namespace @@ -158,7 +124,7 @@ Tensor& glu_out( ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() { ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() { - glu_out_tensor(self, non_negative_dim, out); + glu_out_tensor(ctx, self, non_negative_dim, out); }); }); diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h index 4d3ba46b51b..d372767819a 100644 --- a/kernels/portable/cpu/util/broadcast_indexes_range.h +++ b/kernels/portable/cpu/util/broadcast_indexes_range.h @@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s( std::equal(lhs_begin, lhs_end, rhs_begin); } -template +template class BroadcastIndexesIterator { public: using difference_type = ssize_t; @@ -57,8 +57,11 @@ class BroadcastIndexesIterator { template explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args) : output_dim_or_zero_if_no_broadcasting_( - (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) && - ...) + !support_noncontiguous_tensors && + (sizes_match_ignoring_leading_1s( + args.sizes(), + output.sizes()) && + ...) ? 0 : output.dim()), output_shape_(output.sizes()) { @@ -66,7 +69,8 @@ class BroadcastIndexesIterator { sizeof...(args) == kNumInputs && (std::is_same_v && ...), "BroadcastIndexesIterator constructor requires kNumInputs input tensor" "arguments!"); - if (output_dim_or_zero_if_no_broadcasting_ != 0) { + if (support_noncontiguous_tensors || + output_dim_or_zero_if_no_broadcasting_ != 0) { effective_input_broadcast_strides_ = { effective_input_broadcast_stride(output, args)...}; } @@ -249,11 +253,17 @@ class BroadcastIndexesIterator { * Unlike looping using delinearize_index() and * linearize_access_indexes(), BroadcastIndexesRange avoids expensive * division and modulo operations on each iteration. + * + * The support_noncontiguous_tensors argument disables an optimization + * that causes the iterators not to respect strides in some + * cases. This optimization is normally safe because ExecuTorch + * tensors are contiguous. */ -template +template class BroadcastIndexesRange { public: - using iterator = internal::BroadcastIndexesIterator; + using iterator = internal:: + BroadcastIndexesIterator; template BroadcastIndexesRange(const Tensor& output, const Args&... args) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index e30b8af7d89..722483ec363 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -51,9 +51,19 @@ inline int64_t scalar_to(const Scalar& s) { } namespace internal { +/** + * Causes these utility functions to make sure to respect Tensor + * strides; normally, this is not strictly necessary because ExecuTorch + * Tensors are contiguous. + */ +struct SupportNoncontiguousTensors { + explicit SupportNoncontiguousTensors() = default; +}; + template < typename CTYPE_COMPUTE, typename CTYPE_OUT, + bool support_noncontiguous_tensors, typename Op, typename... Args> inline void dtype_specialized_elementwise_fn_impl( @@ -75,7 +85,8 @@ inline void dtype_specialized_elementwise_fn_impl( CTYPE_OUT* const data_out = out.mutable_data_ptr(); const auto range = - BroadcastIndexesRange(out, (*inputs.first)...); + BroadcastIndexesRange( + out, (*inputs.first)...); auto begin_it = range.begin(); begin_it += begin; for (; (*begin_it)[0] < end; ++begin_it) { @@ -117,6 +128,7 @@ inline bool validate_elementwise_fn_inputs( template < typename CTYPE_COMPUTE, const char* op_name, + bool support_noncontiguous_tensors, typename Op, typename... Args> inline void apply_elementwise_fn_generic_impl( @@ -151,7 +163,8 @@ inline void apply_elementwise_fn_generic_impl( ::executorch::extension::internal::GRAIN_SIZE, [&](const auto begin, const auto end) { const auto range = - BroadcastIndexesRange(out, (*inputs.first)...); + BroadcastIndexesRange( + out, (*inputs.first)...); auto begin_it = range.begin(); begin_it += begin; for (; (*begin_it)[0] < end; ++begin_it) { @@ -187,7 +200,10 @@ inline void apply_elementwise_fn_runtime_out_dtypes( return; } - apply_elementwise_fn_generic_impl( + apply_elementwise_fn_generic_impl< + CTYPE_COMPUTE, + op_name, + /*support_noncontiguous_tensors*/ false>( compute_fun, ctx, out, out_dtypes, inputs...); } @@ -195,6 +211,7 @@ template < typename CTYPE_COMPUTE, const char* op_name, SupportedTensorDtypes out_dtypes, + bool support_noncontiguous_tensors, typename Op, typename... Args> inline void apply_elementwise_fn( @@ -218,12 +235,17 @@ inline void apply_elementwise_fn( out.scalar_type() == out_specialized_scalar_type) { using CTYPE_OUT = typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl( - compute_fun, ctx, out, inputs...); + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); return; } - apply_elementwise_fn_generic_impl( + apply_elementwise_fn_generic_impl< + CTYPE_COMPUTE, + op_name, + support_noncontiguous_tensors>( compute_fun, ctx, out, out_dtypes, inputs...); } @@ -251,7 +273,31 @@ inline void apply_unitensor_elementwise_fn( const Tensor& a, SupportedTensorDtypes a_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ false>( + compute_fun, ctx, out, std::make_pair(&a, a_dtypes)); +} + +template < + typename CTYPE_COMPUTE, + const char* op_name, + SupportedTensorDtypes out_dtypes, + typename Op> +inline void apply_unitensor_elementwise_fn( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& a, + SupportedTensorDtypes a_dtypes, + const Tensor& out, + SupportNoncontiguousTensors) { + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ true>( compute_fun, ctx, out, std::make_pair(&a, a_dtypes)); } @@ -295,7 +341,37 @@ inline void apply_bitensor_elementwise_fn( const Tensor& b, SupportedTensorDtypes b_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ false>( + compute_fun, + ctx, + out, + std::make_pair(&a, a_dtypes), + std::make_pair(&b, b_dtypes)); +} + +template < + typename CTYPE_COMPUTE, + const char* op_name, + SupportedTensorDtypes out_dtypes, + typename Op> +inline void apply_bitensor_elementwise_fn( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& a, + SupportedTensorDtypes a_dtypes, + const Tensor& b, + SupportedTensorDtypes b_dtypes, + const Tensor& out, + SupportNoncontiguousTensors) { + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ true>( compute_fun, ctx, out, @@ -363,7 +439,40 @@ inline void apply_tritensor_elementwise_fn( const Tensor& c, SupportedTensorDtypes c_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ false>( + compute_fun, + ctx, + out, + std::make_pair(&a, a_dtypes), + std::make_pair(&b, b_dtypes), + std::make_pair(&c, c_dtypes)); +} + +template < + typename CTYPE_COMPUTE, + const char* op_name, + SupportedTensorDtypes out_dtypes, + typename Op> +inline void apply_tritensor_elementwise_fn( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& a, + SupportedTensorDtypes a_dtypes, + const Tensor& b, + SupportedTensorDtypes b_dtypes, + const Tensor& c, + SupportedTensorDtypes c_dtypes, + const Tensor& out, + SupportNoncontiguousTensors) { + internal::apply_elementwise_fn< + CTYPE_COMPUTE, + op_name, + out_dtypes, + /*support_noncontiguous_tensors*/ true>( compute_fun, ctx, out, diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp index f8bf22dae63..ac931302f98 100644 --- a/kernels/test/op_glu_test.cpp +++ b/kernels/test/op_glu_test.cpp @@ -28,9 +28,10 @@ class OpGluOutTest : public OperatorTest { return torch::executor::aten::glu_outf(context_, self, dim, out); } - template + template void expect_tensor_close(Tensor actual, Tensor expected) { - if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16 || + OUT_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::BFloat16) { EXPECT_TENSOR_CLOSE_WITH_TOL( actual, expected, @@ -51,20 +52,20 @@ class OpGluOutTest : public OperatorTest { const std::vector out_sizes_1 = {2, 2}; // Valid input should give the expected output - Tensor in = tf.ones(sizes); + Tensor in = tf.make(sizes, {0, 1, 2, 3, 4, 5, 6, 7}); Tensor out = tf_out.zeros(out_sizes_1); op_glu_out(in, 0, out); - expect_tensor_close( + expect_tensor_close( out, tf_out.make( - out_sizes_1, /*data=*/{0.731059, 0.731059, 0.731059, 0.731059})); + out_sizes_1, /*data=*/{0, 0.99330717, 1.99505484, 2.99726701})); const std::vector out_sizes_2 = {4, 1}; out = tf_out.zeros(out_sizes_2); op_glu_out(in, 1, out); - expect_tensor_close( + expect_tensor_close( out, tf_out.make( - out_sizes_2, /*data=*/{0.731059, 0.731059, 0.731059, 0.731059})); + out_sizes_2, /*data=*/{0, 1.90514827, 3.97322869, 5.99453402})); } // Mismatched shape tests. diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index a731ce5c674..96941590dd4 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -618,6 +618,7 @@ ATEN_OPS = ( name = "op_glu", deps = [ "//executorch/kernels/portable/cpu/util:activation_ops_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", ], From c4b0b45a02be7c2e31356cfa5ef8ff6236f9fd5b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 10 Jun 2025 09:50:53 -0700 Subject: [PATCH 2/3] save a little size on "Reapply #11294 and #11295 (improve GLU test and implement using internal views to avoid copying)" These were reverted due to internal test failures. Sending this as an exported internal diff so that we can make sure we get internal signal. Original summary for #11294 (to make the GLU test input asymmetric): This way it will produce different results along each tested dim. Original summaryfor #11295: GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the ATen that uses views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35) entirely internally to the op. To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional SupportNonContiguousTensors parameter. Differential Revision: [D76311585](https://our.internmc.facebook.com/intern/diff/D76311585/) [ghstack-poisoned] --- kernels/portable/cpu/op_glu.cpp | 72 +++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp index be76a158182..672159d86f2 100644 --- a/kernels/portable/cpu/op_glu.cpp +++ b/kernels/portable/cpu/op_glu.cpp @@ -24,6 +24,46 @@ using ScalarType = executorch::aten::ScalarType; namespace { +struct SplitGLUInputTensor { + explicit SplitGLUInputTensor(const Tensor& self, int64_t dim); + using SizesArray = std::array; + SizesArray half_sizes; + TensorImpl first_half_impl; + TensorImpl second_half_impl; + Tensor first_half; + Tensor second_half; + + private: + static SizesArray get_half_sizes(const Tensor& self, int64_t dim) { + SizesArray half_sizes; + std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin()); + half_sizes[dim] /= 2; + return half_sizes; + } +}; + +SplitGLUInputTensor::SplitGLUInputTensor(const Tensor& self, int64_t dim) + : half_sizes(get_half_sizes(self, dim)), + first_half_impl( + self.scalar_type(), + self.dim(), + half_sizes.data(), + self.mutable_data_ptr(), + const_cast(self.dim_order().data()), + const_cast(self.strides().data()), + self.shape_dynamism()), + second_half_impl( + self.scalar_type(), + self.dim(), + half_sizes.data(), + reinterpret_cast(self.mutable_data_ptr()) + + self.strides()[dim] * self.size(dim) / 2 * self.element_size(), + const_cast(self.dim_order().data()), + const_cast(self.strides().data()), + self.shape_dynamism()), + first_half(&first_half_impl), + second_half(&second_half_impl) {} + /** * Applies the gated linear unit function * @@ -39,34 +79,12 @@ Tensor& glu_out_tensor( const Tensor& self, int64_t dim, Tensor& out) { - const auto self_size = self.size(dim); ET_KERNEL_CHECK( ctx, self.dim() <= static_cast(kTensorDimensionLimit), InvalidArgument, out); - std::array half_sizes; - std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin()); - half_sizes[dim] /= 2; - TensorImpl first_half_impl( - self.scalar_type(), - self.dim(), - half_sizes.data(), - self.mutable_data_ptr(), - const_cast(self.dim_order().data()), - const_cast(self.strides().data()), - self.shape_dynamism()); - TensorImpl second_half_impl( - self.scalar_type(), - self.dim(), - half_sizes.data(), - reinterpret_cast(self.mutable_data_ptr()) + - self.strides()[dim] * self_size / 2 * self.element_size(), - const_cast(self.dim_order().data()), - const_cast(self.strides().data()), - self.shape_dynamism()); - Tensor first_half(&first_half_impl); - Tensor second_half(&second_half_impl); + SplitGLUInputTensor split_input(self, dim); ScalarType compute_type = executorch::runtime::isFloatingType(self.scalar_type()) ? self.scalar_type() @@ -79,14 +97,16 @@ Tensor& glu_out_tensor( op_name, utils::SupportedTensorDtypes::FLOATHBF16>( [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE { - // TODO: rewrite this to be vectorization-capable. + // TODO: rewrite this to be vectorization-capable? the + // tensors might not be contiguous; need to have + // apply_bitensor_elementwise_fn check that. const auto one = static_cast(1.0); return val_a * (one / (one + std::exp(-val_b))); }, ctx, - first_half, + split_input.first_half, utils::SupportedTensorDtypes::FLOATHBF16, - second_half, + split_input.second_half, utils::SupportedTensorDtypes::FLOATHBF16, out, utils::internal::SupportNoncontiguousTensors()); From 3d3933261ae97763640f8c4d1d56faeaeae04fc0 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 10 Jun 2025 10:01:04 -0700 Subject: [PATCH 3/3] format on "Reapply #11294 and #11295 (improve GLU test and implement using internal views to avoid copying)" These were reverted due to internal test failures. Sending this as an exported internal diff so that we can make sure we get internal signal. Original summary for #11294 (to make the GLU test input asymmetric): This way it will produce different results along each tested dim. Original summaryfor #11295: GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the ATen that uses views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35) entirely internally to the op. To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional SupportNonContiguousTensors parameter. Differential Revision: [D76311585](https://our.internmc.facebook.com/intern/diff/D76311585/) [ghstack-poisoned] --- kernels/portable/cpu/op_glu.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp index 672159d86f2..3ac04f087ab 100644 --- a/kernels/portable/cpu/op_glu.cpp +++ b/kernels/portable/cpu/op_glu.cpp @@ -26,7 +26,8 @@ namespace { struct SplitGLUInputTensor { explicit SplitGLUInputTensor(const Tensor& self, int64_t dim); - using SizesArray = std::array; + using SizesArray = + std::array; SizesArray half_sizes; TensorImpl first_half_impl; TensorImpl second_half_impl; @@ -57,7 +58,7 @@ SplitGLUInputTensor::SplitGLUInputTensor(const Tensor& self, int64_t dim) self.dim(), half_sizes.data(), reinterpret_cast(self.mutable_data_ptr()) + - self.strides()[dim] * self.size(dim) / 2 * self.element_size(), + self.strides()[dim] * self.size(dim) / 2 * self.element_size(), const_cast(self.dim_order().data()), const_cast(self.strides().data()), self.shape_dynamism()),