From f2789765779bda796a657c41c10451dfb0134e78 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 10 Mar 2025 16:20:30 -0700
Subject: [PATCH 1/5] Update

[ghstack-poisoned]
---
 kernels/portable/cpu/util/reduce_util.h | 126 +++++++++++++++++-------
 1 file changed, 89 insertions(+), 37 deletions(-)

diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index 35cfdfbaa72..4f03dac1009 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -45,7 +45,7 @@ template <typename Fn>
 void apply_on_flat_ix_with_dim_mask_and_base(
     const Fn& fn,
     const Tensor& in,
-    bool* dim_mask,
+    const bool* dim_mask,
     const size_t base,
     const size_t start,
     const size_t end) {
@@ -295,6 +295,92 @@ void apply_over_dim(
   }
 }
 
+/**
+ * Execution plan for repeated apply_over_dim_list with the same
+ * function, input tensor, dim list, start, and end but varying
+ * out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
+ */
+class ApplyOverDimListPlan {
+ public:
+  ApplyOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      // If set, lifetime must last until execute() returns.
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list,
+      const int64_t start = 0,
+      const int64_t end = -1)
+      : in_(in) {
+    ET_CHECK(check_dim_list_is_valid(in, dim_list));
+    out_numel_ = get_out_numel(in_, dim_list);
+    if (in.numel() == 0) {
+      mode_ = ExecutionMode::NothingToDo;
+      return;
+    }
+    const size_t iter_length = get_reduced_dim_product(in, dim_list);
+    const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
+    const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
+    ustart_ = std::max(normalized_start, size_t(0));
+    uend_ = std::min(normalized_end, iter_length - 1);
+    if (!dim_list.has_value() || dim_list.value().size() == 0 ||
+        in.dim() == 0) {
+      mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
+      return;
+    }
+    dim_list_ = dim_list.value();
+    is_in_dim_list_.fill(0);
+    for (const auto& d : dim_list.value()) {
+      const size_t non_neg_d = d < 0 ? d + in.dim() : d;
+      is_in_dim_list_[non_neg_d] = true;
+    }
+
+    mode_ = ExecutionMode::NormalDimMask;
+  }
+
+  template <typename Fn>
+  void execute(const Fn& fn, const size_t out_ix) const {
+    ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
+
+    switch (mode_) {
+      case ExecutionMode::NothingToDo:
+        return;
+      case ExecutionMode::NoDimMaskOrZeroDimension:
+        apply_on_flat_ix_with_stride_and_base(
+            fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
+        return;
+      case ExecutionMode::NormalDimMask:
+        apply_on_flat_ix_with_dim_mask_and_base(
+            fn,
+            in_,
+            is_in_dim_list_.data(),
+            get_init_index(in_, dim_list_, out_ix),
+            ustart_,
+            uend_);
+        return;
+    }
+  }
+
+ private:
+  // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t ustart_;
+  // End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t uend_;
+  enum class ExecutionMode {
+    // Empty input, no work to do.
+    NothingToDo,
+    // Iterate over the entire tensor with
+    // apply_on_flat_ix_with_stride_and_base.
+    NoDimMaskOrZeroDimension,
+    // General mode, iterate with
+    // apply_on_flat_ix_with_dim_mask_and_base.
+    NormalDimMask
+  };
+  ExecutionMode mode_;
+  size_t out_numel_;
+  executorch::aten::ArrayRef<int64_t> dim_list_;
+  std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
+  const executorch::aten::Tensor& in_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
@@ -311,42 +397,8 @@ void apply_over_dim_list(
     const size_t out_ix,
     const int64_t start = 0,
     const int64_t end = -1) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  if (in.numel() == 0) {
-    return;
-  }
-
-  const size_t iter_length = get_reduced_dim_product(in, dim_list);
-  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
-  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
-  const size_t ustart = std::max(normalized_start, size_t(0));
-  const size_t uend = std::min(normalized_end, iter_length - 1);
-
-  // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
-  if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
-    apply_on_flat_ix_with_stride_and_base(
-        fn, /*stride=*/1, /*base=*/0, ustart, uend);
-    return;
-  }
-
-  // Create is_in_dims to check whether each dimension is in the dim list
-  bool is_in_dim_list[kTensorDimensionLimit];
-  memset(is_in_dim_list, false, sizeof(is_in_dim_list));
-  for (const auto& d : dim_list.value()) {
-    const size_t non_neg_d = d < 0 ? d + in.dim() : d;
-    is_in_dim_list[non_neg_d] = true;
-  }
-
-  // Compute the starting base index
-  const size_t base = get_init_index(in, dim_list, out_ix);
-
-  apply_on_flat_ix_with_dim_mask_and_base(
-      fn, in, is_in_dim_list, base, ustart, uend);
+  ApplyOverDimListPlan plan(in, dim_list, start, end);
+  plan.execute(fn, out_ix);
 }
 
 //

From 60e6ce3f240077785baa5c122da50e472796e6a8 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 10 Mar 2025 16:20:35 -0700
Subject: [PATCH 2/5] Update

[ghstack-poisoned]
---
 kernels/portable/cpu/util/reduce_util.h | 92 ++++++++++++++++++---------
 1 file changed, 60 insertions(+), 32 deletions(-)

diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index 4f03dac1009..f299ff8f135 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -309,7 +309,7 @@ class ApplyOverDimListPlan {
           dim_list,
       const int64_t start = 0,
       const int64_t end = -1)
-      : in_(in) {
+      : dim_list_(dim_list), in_(in) {
     ET_CHECK(check_dim_list_is_valid(in, dim_list));
     out_numel_ = get_out_numel(in_, dim_list);
     if (in.numel() == 0) {
@@ -352,13 +352,22 @@ class ApplyOverDimListPlan {
             fn,
             in_,
             is_in_dim_list_.data(),
-            get_init_index(in_, dim_list_, out_ix),
+            get_init_index(in_, dim_list_.value(), out_ix),
             ustart_,
             uend_);
         return;
     }
   }
 
+  const executorch::aten::Tensor& get_input_tensor() const {
+    return in_;
+  }
+
+  const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+  get_dim_list() const {
+    return dim_list_;
+  }
+
  private:
   // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
   size_t ustart_;
@@ -376,7 +385,7 @@
   };
   ExecutionMode mode_;
   size_t out_numel_;
-  executorch::aten::ArrayRef<int64_t> dim_list_;
+  executorch::aten::optional<executorch::aten::ArrayRef<int64_t>> dim_list_;
   std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
   const executorch::aten::Tensor& in_;
 };
@@ -482,6 +491,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
   return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
 }
 
+/**
+ * Execution plan for repeated map_reduce_over_dim_list with the same
+ * function, input tensor, and dim_list but varying out_ix.
+ */
+class MapReduceOverDimListPlan {
+ public:
+  MapReduceOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list)
+      : plan_(in, dim_list, 1, -1) {
+    ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
+  }
+
+  template <
+      typename CTYPE_IN,
+      typename CTYPE_OUT,
+      typename MapOp,
+      typename ReduceOp>
+  CTYPE_OUT execute(
+      const MapOp& map_fun,
+      const ReduceOp& reduce_fun,
+      const size_t out_ix) const {
+    const size_t init_index =
+        get_init_index(plan_.get_input_tensor(), plan_.get_dim_list(), out_ix);
+
+    const CTYPE_IN* const in_data =
+        plan_.get_input_tensor().const_data_ptr<CTYPE_IN>();
+    CTYPE_OUT acc_val = map_fun(in_data[init_index]);
+
+    if (plan_.get_input_tensor().numel() == 1) {
+      return acc_val;
+    }
+
+    plan_.execute(
+        [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
+          acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
+        },
+        out_ix);
+    return acc_val;
+  }
+
+ private:
+  ApplyOverDimListPlan plan_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix`, first applying the map `map_fun`
@@ -517,35 +572,8 @@ CTYPE_OUT map_reduce_over_dim_list(
     const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
         dim_list,
     const size_t out_ix) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
-
-  const size_t init_index = get_init_index(in, dim_list, out_ix);
-
-  const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT acc_val = map_fun(in_data[init_index]);
-
-  if (in.numel() == 1) {
-    return acc_val;
-  }
-
-  apply_over_dim_list(
-      [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
-        acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
-      },
-      in,
-      dim_list,
-      out_ix,
-      1,
-      -1);
-
-  return acc_val;
+  MapReduceOverDimListPlan plan(in, dim_list);
+  return plan.execute<CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
 }
 
 /**

From 1d43a907893683803627916f430b01b1be043732 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 10 Mar 2025 16:20:39 -0700
Subject: [PATCH 3/5] Update

[ghstack-poisoned]
---
 kernels/portable/cpu/util/reduce_util.h | 26 +++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index f299ff8f135..6e9a8540286 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -606,6 +606,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
       [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
 }
 
+/**
+ * Execution plan for repeated reduce_over_dim_list with the same
+ * function, input tensor, and dim_list but varying out_ix.
+ */
+class ReduceOverDimListPlan {
+ public:
+  ReduceOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list)
+      : plan_(in, dim_list) {}
+
+  template <typename CTYPE, typename ReduceOp>
+  CTYPE execute(const ReduceOp& reduce_fun, const size_t out_ix) {
+    return plan_.execute<CTYPE, CTYPE>(
+        [](CTYPE v) { return v; }, reduce_fun, out_ix);
+  }
+
+ private:
+  MapReduceOverDimListPlan plan_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
@@ -632,8 +654,8 @@ CTYPE reduce_over_dim_list(
     const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
         dim_list,
     const size_t out_ix) {
-  return map_reduce_over_dim_list<CTYPE, CTYPE>(
-      [](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
+  ReduceOverDimListPlan plan(in, dim_list);
+  return plan.execute<CTYPE>(reduce_fun, out_ix);
 }
 
 //

From 22d13daa0ec38caf560fc231f9d668b60a5383c5 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 10 Mar 2025 16:20:44 -0700
Subject: [PATCH 4/5] Update

[ghstack-poisoned]
---
 kernels/portable/cpu/op_amax.cpp |  5 ++---
 kernels/portable/cpu/op_amin.cpp |  5 ++---
 kernels/portable/cpu/op_any.cpp  | 11 +++++++----
 kernels/portable/cpu/op_mean.cpp |  5 ++---
 kernels/portable/cpu/op_sum.cpp  | 10 ++++++----
 kernels/portable/cpu/op_var.cpp  |  9 +++------
 6 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/kernels/portable/cpu/op_amax.cpp b/kernels/portable/cpu/op_amax.cpp
index d36f416c7b4..6030221d883 100644
--- a/kernels/portable/cpu/op_amax.cpp
+++ b/kernels/portable/cpu/op_amax.cpp
@@ -43,15 +43,14 @@ Tensor& amax_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
          [](CTYPE v, CTYPE max_v) {
            return std::isnan(v) || v > max_v ? v : max_v;
          },
-          in,
-          dim_list,
          out_ix);
    }
  });
diff --git a/kernels/portable/cpu/op_amin.cpp b/kernels/portable/cpu/op_amin.cpp
index 7c4c8186e59..e4979390a5d 100644
--- a/kernels/portable/cpu/op_amin.cpp
+++ b/kernels/portable/cpu/op_amin.cpp
@@ -42,15 +42,14 @@ Tensor& amin_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
          [](CTYPE v, CTYPE min_v) {
            return std::isnan(v) || v < min_v ? v : min_v;
          },
-          in,
-          dim_list,
          out_ix);
    }
  });
diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp
index 2cfdf36740b..52cfe764703 100644
--- a/kernels/portable/cpu/op_any.cpp
+++ b/kernels/portable/cpu/op_any.cpp
@@ -79,6 +79,11 @@ Tensor& any_dims_out(
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.dims_out";
 
+  const bool in_not_empty = in.numel() > 0;
+  std::optional<MapReduceOverDimListPlan> plan;
+  if ((!dim_list.has_value() || !dim_list.value().empty()) && in_not_empty) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
@@ -91,12 +96,10 @@ Tensor& any_dims_out(
       } else {
         for (const auto out_ix : c10::irange(out.numel())) {
           bool any = false;
-          if (in.numel() > 0) {
-            any = map_reduce_over_dim_list<CTYPE_IN, bool>(
+          if (in_not_empty) {
+            any = plan->execute<CTYPE_IN, bool>(
                [](CTYPE_IN v) { return static_cast<bool>(v); },
                [](bool outv, bool acc) { return acc || outv; },
-                in,
-                dim_list,
                out_ix);
          }
          out_data[out_ix] = static_cast<CTYPE_OUT>(any);
diff --git a/kernels/portable/cpu/op_mean.cpp b/kernels/portable/cpu/op_mean.cpp
index 77f74ae7cac..c13e2a09937 100644
--- a/kernels/portable/cpu/op_mean.cpp
+++ b/kernels/portable/cpu/op_mean.cpp
@@ -45,6 +45,7 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);
 
+  MapReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
     ET_SWITCH_FLOATHBF16_TYPES(
         out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
@@ -53,11 +54,9 @@ Tensor& mean_dim_out(
           for (const auto out_ix : c10::irange(out.numel())) {
             CTYPE_OUT sum = 0;
             if (in.numel() > 0) {
-              sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+              sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
                  [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                  [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                  in,
-                  dim_list,
                  out_ix);
            }
            out_data[out_ix] = sum / static_cast<CTYPE_OUT>(num);
diff --git a/kernels/portable/cpu/op_sum.cpp b/kernels/portable/cpu/op_sum.cpp
index 81cf4b5a175..3bc1ba98dc5 100644
--- a/kernels/portable/cpu/op_sum.cpp
+++ b/kernels/portable/cpu/op_sum.cpp
@@ -44,6 +44,10 @@ Tensor& sum_dim_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  std::optional<MapReduceOverDimListPlan> plan;
+  if (in.numel() > 0) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(
       in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
         ET_SWITCH_REALHBBF16_TYPES(
@@ -51,12 +55,10 @@ Tensor& sum_dim_out(
              CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
              for (const auto out_ix : c10::irange(out.numel())) {
                CTYPE_OUT sum = 0;
-                if (in.numel() > 0) {
-                  sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                if (plan.has_value()) {
+                  sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
                      [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                      [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                      in,
-                      dim_list,
                      out_ix);
                }
                out_data[out_ix] = sum;
diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp
index 0cffca450c8..c5be3fdad62 100644
--- a/kernels/portable/cpu/op_var.cpp
+++ b/kernels/portable/cpu/op_var.cpp
@@ -32,23 +32,20 @@ void compute_variance(
       out_data[out_ix] = NAN;
     }
   } else {
+    MapReduceOverDimListPlan plan(in, dim_list);
     for (const auto out_ix : c10::irange(out.numel())) {
-      CTYPE_OUT sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
          [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
          out_ix);
      CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
-      CTYPE_OUT sum2 = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
          [mean](CTYPE_IN v) {
            return (
                (static_cast<CTYPE_OUT>(v) - mean) *
                (static_cast<CTYPE_OUT>(v) - mean));
          },
          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
          out_ix);
      out_data[out_ix] = sum2 / denominator;
    }

From f60b959d9178f521607e1573cd89a24e4a13c93b Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 10 Mar 2025 16:59:15 -0700
Subject: [PATCH 5/5] Update

[ghstack-poisoned]
---
 kernels/portable/cpu/op_any.cpp | 2 ++
 kernels/portable/cpu/op_sum.cpp | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp
index 52cfe764703..a9dd79ad34d 100644
--- a/kernels/portable/cpu/op_any.cpp
+++ b/kernels/portable/cpu/op_any.cpp
@@ -10,6 +10,8 @@
 #include
 #include
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
diff --git a/kernels/portable/cpu/op_sum.cpp b/kernels/portable/cpu/op_sum.cpp
index 3bc1ba98dc5..f58773a6769 100644
--- a/kernels/portable/cpu/op_sum.cpp
+++ b/kernels/portable/cpu/op_sum.cpp
@@ -11,6 +11,8 @@
 #include
 #include
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
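
Usage sketch (illustrative, not part of the patches above): after this stack, a reduction kernel builds one plan object outside its per-output loop and only varies out_ix inside it, which is the pattern op_amax.cpp, op_sum.cpp, and op_var.cpp adopt. The snippet below assumes reduce_util.h is included and that a nonempty input tensor `in`, an output tensor `out`, a `dim_list`, and a scalar type CTYPE are supplied by the surrounding kernel; a real kernel would keep its existing empty-input handling, since the plan classes check that the input is nonempty.

  // Hoist plan construction out of the loop: dim_list validation and the
  // dim-mask setup run once instead of once per output element.
  ReduceOverDimListPlan plan(in, dim_list);
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  for (const auto out_ix : c10::irange(out.numel())) {
    // Only the output index changes between iterations.
    out_data[out_ix] = plan.execute<CTYPE>(
        [](CTYPE v, CTYPE acc) { return acc + v; }, out_ix);
  }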