 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>
 
@@ -58,6 +62,19 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
 using op_call_result =
     std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
 
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
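As a quick illustration of what this trait checks: it boils down to a std::is_invocable_v test against at::vec::Vectorized<CTYPE_COMMON>. The standalone sketch below is not part of the patch; it uses a made-up FakeVec type in place of at::vec::Vectorized<float> so it compiles without ATen headers, but the detection logic is the same.

#include <type_traits>

struct FakeVec { // hypothetical stand-in for at::vec::Vectorized<float>
  FakeVec operator+(const FakeVec&) const { return {}; }
};

int main() {
  // A generic lambda whose body is valid for vector operands is detected as
  // usable with the vectorized path.
  auto generic_add = [](auto a, auto b) { return a + b; };
  static_assert(std::is_invocable_v<decltype(generic_add), FakeVec, FakeVec>);

  // A lambda with fixed scalar parameter types is not invocable with vectors,
  // so only the scalar path applies.
  auto scalar_add = [](float a, float b) { return a + b; };
  static_assert(!std::is_invocable_v<decltype(scalar_add), FakeVec, FakeVec>);
  return 0;
}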
@@ -68,14 +85,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));
 
   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
 
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
   ::executorch::extension::parallel_for(
       0,
       out.numel(),
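The begin/end arithmetic in the vectorized lambda above splits each parallel_for chunk into an unaligned scalar prologue, a main loop over whole vectors, and a scalar tail. Below is a simplified, self-contained sketch of that split, not taken from the patch: kVecSize is a hypothetical stand-in for Vec::size(), the "vector op" is just an inner scalar loop, and the std::min/std::max clamping is added here so the sketch stays in-bounds even for ranges shorter than kVecSize.

#include <algorithm>
#include <cstddef>

constexpr std::size_t kVecSize = 8; // hypothetical vector width

// Doubles every element of in[begin, end) into out, processing kVecSize
// elements per "vector" step where possible.
void double_range(
    const float* in, float* out, std::size_t begin, std::size_t end) {
  // First index at or after `begin` that is a multiple of kVecSize, clamped so
  // a tiny range degenerates gracefully to all-scalar processing.
  const std::size_t vectorized_begin =
      std::min(end, begin + (kVecSize - begin % kVecSize) % kVecSize);
  const std::size_t vectorized_end =
      std::max(vectorized_begin, end - (end % kVecSize));

  // Scalar prologue: elements before the first aligned index.
  for (std::size_t i = begin; i < vectorized_begin; ++i) {
    out[i] = 2.0f * in[i];
  }
  // Main loop: whole groups of kVecSize elements (a SIMD load/compute/store in
  // the real kernel; a plain inner loop here).
  for (std::size_t i = vectorized_begin; i < vectorized_end; i += kVecSize) {
    for (std::size_t lane = 0; lane < kVecSize; ++lane) {
      out[i + lane] = 2.0f * in[i + lane];
    }
  }
  // Scalar epilogue: leftover tail.
  for (std::size_t i = vectorized_end; i < end; ++i) {
    out[i] = 2.0f * in[i];
  }
}

int main() {
  float in[20], out[20] = {};
  for (int i = 0; i < 20; ++i) in[i] = static_cast<float>(i);
  double_range(in, out, 3, 20); // prologue [3,8), vectors [8,16), tail [16,20)
  return out[19] == 38.0f ? 0 : 1;
}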
@@ -255,6 +330,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ *    https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ *    .
+ */
 template <
     typename CTYPE_COMMON,
     const char* op_name,
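For condition 2) in the note above, one common C++17 pattern is to give the generic lambda a trailing decltype return type, so that an invalid argument type makes the lambda merely non-invocable (which checks like can_use_vectorized can detect) instead of triggering a hard compile error inside its body. A hedged standalone sketch, again with a hypothetical FakeVec standing in for at::vec::Vectorized:

#include <cmath>
#include <type_traits>

struct FakeVec {}; // hypothetical stand-in for at::vec::Vectorized<float>

int main() {
  // The trailing `-> decltype(...)` puts the expression into the immediate
  // context, so a substitution failure removes the call operator from overload
  // resolution rather than erroring out.
  auto scalar_only_exp = [](const auto& x) -> decltype(std::exp(x)) {
    return std::exp(x);
  };
  static_assert(std::is_invocable_v<decltype(scalar_only_exp), float>);
  // std::exp(FakeVec) is ill-formed, so the lambda is reported as simply not
  // invocable for FakeVec; a vectorization check can then fall back to the
  // scalar path instead of failing to compile.
  static_assert(!std::is_invocable_v<decltype(scalar_only_exp), FakeVec>);
  return 0;
}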
@@ -296,6 +384,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,
@@ -362,6 +452,9 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,