diff --git a/kernels/portable/cpu/op_amax.cpp b/kernels/portable/cpu/op_amax.cpp
index d36f416c7b4..6030221d883 100644
--- a/kernels/portable/cpu/op_amax.cpp
+++ b/kernels/portable/cpu/op_amax.cpp
@@ -43,15 +43,14 @@ Tensor& amax_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
           [](CTYPE v, CTYPE max_v) {
             return std::isnan(v) || v > max_v ? v : max_v;
           },
-          in,
-          dim_list,
           out_ix);
     }
   });
diff --git a/kernels/portable/cpu/op_amin.cpp b/kernels/portable/cpu/op_amin.cpp
index 7c4c8186e59..e4979390a5d 100644
--- a/kernels/portable/cpu/op_amin.cpp
+++ b/kernels/portable/cpu/op_amin.cpp
@@ -42,15 +42,14 @@ Tensor& amin_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
           [](CTYPE v, CTYPE min_v) {
             return std::isnan(v) || v < min_v ? v : min_v;
           },
-          in,
-          dim_list,
           out_ix);
     }
   });
diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp
index 2cfdf36740b..a9dd79ad34d 100644
--- a/kernels/portable/cpu/op_any.cpp
+++ b/kernels/portable/cpu/op_any.cpp
@@ -10,6 +10,8 @@
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -79,6 +81,11 @@ Tensor& any_dims_out(
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.dims_out";
 
+  const bool in_not_empty = in.numel() > 0;
+  std::optional<MapReduceOverDimListPlan> plan;
+  if ((!dim_list.has_value() || !dim_list.value().empty()) && in_not_empty) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
@@ -91,12 +98,10 @@ Tensor& any_dims_out(
       } else {
         for (const auto out_ix : c10::irange(out.numel())) {
           bool any = false;
-          if (in.numel() > 0) {
-            any = map_reduce_over_dim_list<CTYPE_IN, bool>(
+          if (in_not_empty) {
+            any = plan->execute<CTYPE_IN, bool>(
                 [](CTYPE_IN v) { return static_cast<bool>(v); },
                 [](bool outv, bool acc) { return acc || outv; },
-                in,
-                dim_list,
                 out_ix);
           }
           out_data[out_ix] = static_cast<CTYPE_OUT>(any);
diff --git a/kernels/portable/cpu/op_mean.cpp b/kernels/portable/cpu/op_mean.cpp
index 77f74ae7cac..c13e2a09937 100644
--- a/kernels/portable/cpu/op_mean.cpp
+++ b/kernels/portable/cpu/op_mean.cpp
@@ -45,6 +45,7 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);
 
+  MapReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
     ET_SWITCH_FLOATHBF16_TYPES(
         out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
@@ -53,11 +54,9 @@ Tensor& mean_dim_out(
           for (const auto out_ix : c10::irange(out.numel())) {
             CTYPE_OUT sum = 0;
             if (in.numel() > 0) {
-              sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+              sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
                   [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                   [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                  in,
-                  dim_list,
                   out_ix);
             }
             out_data[out_ix] = sum / static_cast<CTYPE_OUT>(num);
diff --git a/kernels/portable/cpu/op_sum.cpp b/kernels/portable/cpu/op_sum.cpp
index 81cf4b5a175..f58773a6769 100644
--- a/kernels/portable/cpu/op_sum.cpp
+++ b/kernels/portable/cpu/op_sum.cpp
@@ -11,6 +11,8 @@
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -44,6 +46,10 @@ Tensor& sum_dim_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  std::optional<MapReduceOverDimListPlan> plan;
+  if (in.numel() > 0) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(
       in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
         ET_SWITCH_REALHBBF16_TYPES(
@@ -51,12 +57,10 @@ Tensor& sum_dim_out(
             CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
             for (const auto out_ix : c10::irange(out.numel())) {
               CTYPE_OUT sum = 0;
-              if (in.numel() > 0) {
-                sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+              if (plan.has_value()) {
+                sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
                     [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                     [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                    in,
-                    dim_list,
                     out_ix);
               }
               out_data[out_ix] = sum;
diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp
index 0cffca450c8..c5be3fdad62 100644
--- a/kernels/portable/cpu/op_var.cpp
+++ b/kernels/portable/cpu/op_var.cpp
@@ -32,23 +32,20 @@ void compute_variance(
       out_data[out_ix] = NAN;
     }
   } else {
+    MapReduceOverDimListPlan plan(in, dim_list);
     for (const auto out_ix : c10::irange(out.numel())) {
-      CTYPE_OUT sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
           [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
           [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
           out_ix);
       CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
-      CTYPE_OUT sum2 = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
          [mean](CTYPE_IN v) {
            return (
                (static_cast<CTYPE_OUT>(v) - mean) *
                (static_cast<CTYPE_OUT>(v) - mean));
          },
          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
          out_ix);
      out_data[out_ix] = sum2 / denominator;
    }
diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index 5d14f2752a6..bafb9a4a563 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -329,7 +329,7 @@ class ApplyOverDimListPlan {
           dim_list,
       const int64_t start = 0,
       const int64_t end = -1)
-      : in_(in) {
+      : dim_list_(dim_list), in_(in) {
     ET_CHECK(check_dim_list_is_valid(in, dim_list));
     out_numel_ = get_out_numel(in_, dim_list);
     if (in.numel() == 0) {
@@ -372,13 +372,22 @@ class ApplyOverDimListPlan {
            fn,
            in_,
            is_in_dim_list_.data(),
-            get_init_index(in_, dim_list_, out_ix),
+            get_init_index(in_, dim_list_.value(), out_ix),
            ustart_,
            uend_);
        return;
    }
  }
 
+  const executorch::aten::Tensor& get_input_tensor() const {
+    return in_;
+  }
+
+  const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+  get_dim_list() const {
+    return dim_list_;
+  }
+
 private:
  // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
  size_t ustart_;
@@ -396,7 +405,7 @@ class ApplyOverDimListPlan {
  };
  ExecutionMode mode_;
  size_t out_numel_;
-  executorch::aten::ArrayRef<int64_t> dim_list_;
+  executorch::aten::optional<executorch::aten::ArrayRef<int64_t>> dim_list_;
  std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
  const executorch::aten::Tensor& in_;
};
@@ -502,6 +511,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
  return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
}
 
+/**
+ * Execution plan for repeated map_reduce_over_dim_list with the same
+ * function, input tensor, and dim_list but varying out_ix.
+ */
+class MapReduceOverDimListPlan {
+ public:
+  MapReduceOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list)
+      : plan_(in, dim_list, 1, -1) {
+    ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
+  }
+
+  template <
+      typename CTYPE_IN,
+      typename CTYPE_OUT,
+      typename MapOp,
+      typename ReduceOp>
+  CTYPE_OUT execute(
+      const MapOp& map_fun,
+      const ReduceOp& reduce_fun,
+      const size_t out_ix) const {
+    const size_t init_index =
+        get_init_index(plan_.get_input_tensor(), plan_.get_dim_list(), out_ix);
+
+    const CTYPE_IN* const in_data =
+        plan_.get_input_tensor().const_data_ptr<CTYPE_IN>();
+    CTYPE_OUT acc_val = map_fun(in_data[init_index]);
+
+    if (plan_.get_input_tensor().numel() == 1) {
+      return acc_val;
+    }
+
+    plan_.execute(
+        [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
+          acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
+        },
+        out_ix);
+    return acc_val;
+  }
+
+ private:
+  ApplyOverDimListPlan plan_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix`, first applying the map `map_fun`
@@ -537,35 +592,8 @@ CTYPE_OUT map_reduce_over_dim_list(
     const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
         dim_list,
     const size_t out_ix) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
-
-  const size_t init_index = get_init_index(in, dim_list, out_ix);
-
-  const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT acc_val = map_fun(in_data[init_index]);
-
-  if (in.numel() == 1) {
-    return acc_val;
-  }
-
-  apply_over_dim_list(
-      [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
-        acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
-      },
-      in,
-      dim_list,
-      out_ix,
-      1,
-      -1);
-
-  return acc_val;
+  MapReduceOverDimListPlan plan(in, dim_list);
+  return plan.execute<CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
 }
 
 /**
@@ -598,6 +626,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
       [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
 }
 
+/**
+ * Execution plan for repeated reduce_over_dim_list with the same
+ * function, input tensor, and dim_list but varying out_ix.
+ */
+class ReduceOverDimListPlan {
+ public:
+  ReduceOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list)
+      : plan_(in, dim_list) {}
+
+  template <typename CTYPE, typename ReduceOp>
+  CTYPE execute(const ReduceOp& reduce_fun, const size_t out_ix) {
+    return plan_.execute<CTYPE, CTYPE>(
+        [](CTYPE v) { return v; }, reduce_fun, out_ix);
+  }
+
+ private:
+  MapReduceOverDimListPlan plan_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
@@ -624,8 +674,8 @@ CTYPE reduce_over_dim_list(
     const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
         dim_list,
     const size_t out_ix) {
-  return map_reduce_over_dim_list<CTYPE, CTYPE>(
-      [](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
+  ReduceOverDimListPlan plan(in, dim_list);
+  return plan.execute<CTYPE>(reduce_fun, out_ix);
 }
 
 //
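
Usage note (editor's sketch, not part of the patch): the point of the two new plan classes is to hoist the work that reduce_over_dim_list and map_reduce_over_dim_list previously repeated for every output element (dim-list validation and traversal-mode selection inside ApplyOverDimListPlan) out of the per-out_ix loop. The sketch below shows the calling pattern as it appears in the kernels above; max_over_dims is a hypothetical helper invented for illustration, the includes mirror those used by the op files, and a nonempty input is assumed because the MapReduceOverDimListPlan constructor (which ReduceOverDimListPlan wraps) checks numel() > 0.

    // Sketch only: hypothetical helper, not part of the patch.
    #include <executorch/kernels/portable/cpu/util/reduce_util.h>
    #include <executorch/runtime/kernel/kernel_includes.h>

    namespace torch {
    namespace executor {
    namespace native {

    template <typename CTYPE>
    void max_over_dims(
        const executorch::aten::Tensor& in,
        const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
            dim_list,
        executorch::aten::Tensor& out) {
      // Validation and traversal-mode selection happen once, here, rather
      // than once per output element as with the old free-function helpers.
      ReduceOverDimListPlan plan(in, dim_list);
      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
      for (const auto out_ix : c10::irange(out.numel())) {
        // Only the walk over the reduced elements remains inside the loop.
        out_data[out_ix] = plan.execute<CTYPE>(
            [](CTYPE v, CTYPE acc) { return v > acc ? v : acc; }, out_ix);
      }
    }

    } // namespace native
    } // namespace executor
    } // namespace torch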