diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp
index be76a158182..edc82c55eb8 100644
--- a/kernels/portable/cpu/op_glu.cpp
+++ b/kernels/portable/cpu/op_glu.cpp
@@ -8,7 +8,6 @@
 #include
 #include
-#include
 #include
 #include
 #include
@@ -24,6 +23,93 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
+double exp_overload(double d) {
+  return exp(d);
+}
+
+float exp_overload(float f) {
+  return expf(f);
+}
+
+/**
+ * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
+ */
+// TODO: T146333648, refactor this as a common helper function
+template <typename CTYPE_OUT>
+void sigmoid_tensor(Tensor& out) {
+  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+  for (const auto i : c10::irange(out.numel())) {
+    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
+  }
+}
+
+/**
+ * Element-wise multiplication of the first half of `in` along the specified
+ * dimension and `out`, overwriting `out`.
+ */
+template <typename CTYPE_IN, typename CTYPE_OUT>
+void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
+  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
+  size_t dim_length_in = static_cast<size_t>(in.size(dim));
+  size_t dim_length_out = static_cast<size_t>(out.size(dim));
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
+  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
+
+  for (const auto i : c10::irange(leading_dims)) {
+    const CTYPE_IN* input_data =
+        input_data_base + i * dim_length_in * trailing_dims;
+    CTYPE_OUT* output_data =
+        output_data_base + i * dim_length_out * trailing_dims;
+    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
+      for (const auto k : c10::irange(trailing_dims)) {
+        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
+      }
+      input_data += trailing_dims;
+      output_data += trailing_dims;
+    }
+  }
+}
+
+/**
+ * Slice the tensor in the given dim, from start to end. Assumes tensors `in`
+ * and `out` have the same shape and dtype, that `dim` is non-negative, and
+ * that `start` and `end` are valid non-negative indices.
+ */
+template <typename CTYPE_IN, typename CTYPE_OUT>
+void slice_tensor(
+    const Tensor& in,
+    int64_t dim,
+    int64_t start,
+    int64_t end,
+    Tensor& out) {
+  size_t num_values = static_cast<size_t>(end - start);
+  size_t dim_length_in = static_cast<size_t>(in.size(dim));
+  size_t dim_length_out = static_cast<size_t>(out.size(dim));
+  size_t non_negative_start = static_cast<size_t>(start);
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
+  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
+
+  for (const auto i : c10::irange(leading_dims)) {
+    const CTYPE_IN* input_data = input_data_base +
+        (i * dim_length_in + non_negative_start) * trailing_dims;
+    CTYPE_OUT* output_data =
+        output_data_base + i * dim_length_out * trailing_dims;
+    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
+      for (const auto k : c10::irange(trailing_dims)) {
+        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
+      }
+      input_data += trailing_dims;
+      output_data += trailing_dims;
+    }
+  }
+}
+
 /**
  * Applies the gated linear unit function
  *
@@ -34,63 +120,11 @@ namespace {
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(
-    KernelRuntimeContext& ctx,
-    const Tensor& self,
-    int64_t dim,
-    Tensor& out) {
+Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
   const auto self_size = self.size(dim);
-  ET_KERNEL_CHECK(
-      ctx,
-      self.dim() <= static_cast(kTensorDimensionLimit),
-      InvalidArgument,
-      out);
-  std::array half_sizes;
-  std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
-  half_sizes[dim] /= 2;
-  TensorImpl first_half_impl(
-      self.scalar_type(),
-      self.dim(),
-      half_sizes.data(),
-      self.mutable_data_ptr(),
-      const_cast(self.dim_order().data()),
-      const_cast(self.strides().data()),
-      self.shape_dynamism());
-  TensorImpl second_half_impl(
-      self.scalar_type(),
-      self.dim(),
-      half_sizes.data(),
-      reinterpret_cast(self.mutable_data_ptr()) +
-          self.strides()[dim] * self_size / 2 * self.element_size(),
-      const_cast(self.dim_order().data()),
-      const_cast(self.strides().data()),
-      self.shape_dynamism());
-  Tensor first_half(&first_half_impl);
-  Tensor second_half(&second_half_impl);
-  ScalarType compute_type =
-      executorch::runtime::isFloatingType(self.scalar_type())
-      ? self.scalar_type()
-      : ScalarType::Float;
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "glu.out";
-  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
-          // TODO: rewrite this to be vectorization-capable.
-          const auto one = static_cast<CTYPE_COMPUTE>(1.0);
-          return val_a * (one / (one + std::exp(-val_b)));
-        },
-        ctx,
-        first_half,
-        utils::SupportedTensorDtypes::FLOATHBF16,
-        second_half,
-        utils::SupportedTensorDtypes::FLOATHBF16,
-        out,
-        utils::internal::SupportNoncontiguousTensors());
-  });
+  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
+  sigmoid_tensor<CTYPE_OUT>(out);
+  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
   return out;
 }
 } // namespace
@@ -124,7 +158,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
     });
   });
 
diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index d372767819a..4d3ba46b51b 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <std::size_t kNumInputs>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,11 +57,8 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            !support_noncontiguous_tensors &&
-                (sizes_match_ignoring_leading_1s(
-                     args.sizes(),
-                     output.sizes()) &&
-                 ...)
+            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
+             ...)
                 ? 0
                : output.dim()),
        output_shape_(output.sizes()) {
@@ -69,8 +66,7 @@ class BroadcastIndexesIterator {
         sizeof...(args) == kNumInputs && (std::is_same_v && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (support_noncontiguous_tensors ||
-        output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -253,17 +249,11 @@ class BroadcastIndexesIterator {
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
- *
- * The support_noncontiguous_tensors argument disables an optimization
- * that causes the iterators not to respect strides in some
- * cases. This optimization is normally safe because ExecuTorch
- * tensors are contiguous.
  */
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <std::size_t kNumInputs>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::
-      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
+  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 722483ec363..e30b8af7d89 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -51,19 +51,9 @@ inline int64_t scalar_to(const Scalar& s) {
 }
 
 namespace internal {
-/**
- * Causes these utility functions to make sure to respect Tensor
- * strides; normally, this is not strictly necessary because ExecuTorch
- * Tensors are contiguous.
- */
-struct SupportNoncontiguousTensors {
-  explicit SupportNoncontiguousTensors() = default;
-};
-
 template <
     typename CTYPE_COMPUTE,
     typename CTYPE_OUT,
-    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void dtype_specialized_elementwise_fn_impl(
@@ -85,8 +75,7 @@ inline void dtype_specialized_elementwise_fn_impl(
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
   const auto range =
-      BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
-          out, (*inputs.first)...);
+      BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
   auto begin_it = range.begin();
   begin_it += begin;
   for (; (*begin_it)[0] < end; ++begin_it) {
@@ -128,7 +117,6 @@ inline bool validate_elementwise_fn_inputs(
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
-    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void apply_elementwise_fn_generic_impl(
@@ -163,8 +151,7 @@ inline void apply_elementwise_fn_generic_impl(
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
         const auto range =
-            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
-                out, (*inputs.first)...);
+            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
@@ -200,10 +187,7 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
     return;
   }
 
-  apply_elementwise_fn_generic_impl<
-      CTYPE_COMPUTE,
-      op_name,
-      /*support_noncontiguous_tensors*/ false>(
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
       compute_fun, ctx, out, out_dtypes, inputs...);
 }
 
@@ -211,7 +195,6 @@ template <
     typename CTYPE_COMPUTE,
     const char* op_name,
     SupportedTensorDtypes out_dtypes,
-    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void apply_elementwise_fn(
@@ -235,17 +218,12 @@ inline void apply_elementwise_fn(
       out.scalar_type() == out_specialized_scalar_type) {
     using CTYPE_OUT =
         typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
-    dtype_specialized_elementwise_fn_impl<
-        CTYPE_COMPUTE,
-        CTYPE_OUT,
-        support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
+    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
+        compute_fun, ctx, out, inputs...);
     return;
   }
 
-  apply_elementwise_fn_generic_impl<
-      CTYPE_COMPUTE,
-      op_name,
-      support_noncontiguous_tensors>(
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
       compute_fun, ctx, out, out_dtypes, inputs...);
 }
 
@@ -273,31 +251,7 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ false>(
-      compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
-}
-
-template <
-    typename CTYPE_COMPUTE,
-    const char* op_name,
-    SupportedTensorDtypes out_dtypes,
-    typename Op>
-inline void apply_unitensor_elementwise_fn(
-    const Op& compute_fun,
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
-    const Tensor& out,
-    SupportNoncontiguousTensors) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ true>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
 }
 
@@ -341,37 +295,7 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ false>(
-      compute_fun,
-      ctx,
-      out,
-      std::make_pair(&a, a_dtypes),
-      std::make_pair(&b, b_dtypes));
-}
-
-template <
-    typename CTYPE_COMPUTE,
-    const char* op_name,
-    SupportedTensorDtypes out_dtypes,
-    typename Op>
-inline void apply_bitensor_elementwise_fn(
-    const Op& compute_fun,
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
-    const Tensor& b,
-    SupportedTensorDtypes b_dtypes,
-    const Tensor& out,
-    SupportNoncontiguousTensors) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ true>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
       std::make_pair(&a, a_dtypes),
       std::make_pair(&b, b_dtypes));
 }
@@ -439,40 +363,7 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ false>(
-      compute_fun,
-      ctx,
-      out,
-      std::make_pair(&a, a_dtypes),
-      std::make_pair(&b, b_dtypes),
-      std::make_pair(&c, c_dtypes));
-}
-
-template <
-    typename CTYPE_COMPUTE,
-    const char* op_name,
-    SupportedTensorDtypes out_dtypes,
-    typename Op>
-inline void apply_tritensor_elementwise_fn(
-    const Op& compute_fun,
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
-    const Tensor& b,
-    SupportedTensorDtypes b_dtypes,
-    const Tensor& c,
-    SupportedTensorDtypes c_dtypes,
-    const Tensor& out,
-    SupportNoncontiguousTensors) {
-  internal::apply_elementwise_fn<
-      CTYPE_COMPUTE,
-      op_name,
-      out_dtypes,
-      /*support_noncontiguous_tensors*/ true>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp
index ac931302f98..b18117eaa4e 100644
--- a/kernels/test/op_glu_test.cpp
+++ b/kernels/test/op_glu_test.cpp
@@ -28,10 +28,9 @@ class OpGluOutTest : public OperatorTest {
     return torch::executor::aten::glu_outf(context_, self, dim, out);
   }
 
-  template <ScalarType DTYPE, ScalarType OUT_DTYPE>
+  template <ScalarType DTYPE>
   void expect_tensor_close(Tensor actual, Tensor expected) {
-    if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16 ||
-        OUT_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::BFloat16) {
+    if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
       EXPECT_TENSOR_CLOSE_WITH_TOL(
           actual,
           expected,
@@ -55,14 +54,14 @@ class OpGluOutTest : public OperatorTest {
     Tensor in = tf.make(sizes, {0, 1, 2, 3, 4, 5, 6, 7});
     Tensor out = tf_out.zeros(out_sizes_1);
     op_glu_out(in, 0, out);
-    expect_tensor_close<DTYPE, OUT_DTYPE>(
+    expect_tensor_close<DTYPE>(
         out,
         tf_out.make(
            out_sizes_1, /*data=*/{0, 0.99330717, 1.99505484, 2.99726701}));
     const std::vector out_sizes_2 = {4, 1};
     out = tf_out.zeros(out_sizes_2);
     op_glu_out(in, 1, out);
-    expect_tensor_close<DTYPE, OUT_DTYPE>(
+    expect_tensor_close<DTYPE>(
        out,
        tf_out.make(
            out_sizes_2, /*data=*/{0, 1.90514827, 3.97322869, 5.99453402}));
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index 96941590dd4..a731ce5c674 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -618,7 +618,6 @@ ATEN_OPS = (
         name = "op_glu",
         deps = [
             "//executorch/kernels/portable/cpu/util:activation_ops_util",
-            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/runtime/core/exec_aten/util:scalar_type_util",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
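For reference, the restored kernel computes GLU as out = A * sigmoid(B), where A and B are the first and second halves of the input along the split dimension, via the slice_tensor / sigmoid_tensor / mul_tensors steps above. The following is a minimal standalone sketch of that decomposition on a flat row-major buffer; it uses plain C++ and illustrative names (glu_last_dim, rows, half), not the ExecuTorch Tensor API, and only handles splitting along the last dimension.

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: GLU over the last dimension of a row-major [rows, 2 * half] buffer,
// mirroring the kernel's slice -> sigmoid -> multiply sequence.
std::vector<float> glu_last_dim(
    const std::vector<float>& in, std::size_t rows, std::size_t half) {
  std::vector<float> out(rows * half);
  for (std::size_t r = 0; r < rows; ++r) {
    const float* a = in.data() + r * 2 * half; // first half of the row
    const float* b = a + half;                 // second half of the row
    for (std::size_t i = 0; i < half; ++i) {
      const float sig = 1.0f / (1.0f + std::exp(-b[i])); // sigmoid(b)
      out[r * half + i] = a[i] * sig;                    // a * sigmoid(b)
    }
  }
  return out;
}

With the 4x2 test input {0, 1, 2, 3, 4, 5, 6, 7} above and dim = 1 (rows = 4, half = 1), this reproduces the expected values {0, 1.90514827, 3.97322869, 5.99453402} from op_glu_test.cpp.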