 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 
+#include <array>
+#include <utility>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
                              : s.to<int64_t>();
 }
 
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
-inline void apply_unitensor_elementwise_fn(
+namespace internal {
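+// Generalized implementation shared by the uni-/bi-/tri-tensor entry points
+// below. Each input tensor travels as a (const Tensor*, SupportedTensorDtypes)
+// pair so the variadic pack can be validated and expanded uniformly.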
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
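+  // Compile-time check that every trailing argument is a
+  // (const Tensor*, SupportedTensorDtypes) pair; the fold expression applies
+  // std::is_same_v to each element of the pack.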
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
+  constexpr auto kNumInputs = sizeof...(inputs);
   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
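+  // Validate each input's dtype against the common compute type; expanded
+  // once per input by the fold inside ET_KERNEL_CHECK below.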
+  const auto check_input_dtype = [](auto input, auto compute_type) {
+    return internal::check_tensor_dtype(
+        *input.first, input.second, compute_type);
+  };
   ET_KERNEL_CHECK(
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      (check_input_dtype(inputs, compute_type) && ...) &&
+          internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
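+  // An input participates in broadcasting iff its sizes differ from out's.
+  // The unary case is skipped: a single input is assumed to already match
+  // the output shape.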
+  bool any_is_broadcasted = false;
+  if constexpr (kNumInputs > 1) {
+    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
+  }
+
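+  // Per-input metadata gathered once up front: a loader that converts the
+  // input's elements to CTYPE_COMMON, plus the raw data pointer and element
+  // size used for manual byte-offset indexing.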
+  struct InputInfo {
+    load_to_common_fn<CTYPE_COMMON> load_to_common;
+    const char* data_ptr;
+    ssize_t element_size;
+  };
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
+          *inputs.first, inputs.second),
+      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+      inputs.first->element_size(),
+  })...};
+
   const auto store_common_to_out =
       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
           out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const auto out_element_size = out.element_size();
 
-  auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
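+  // Two paths: when any shape differs, BroadcastIndexesRange yields one index
+  // per tensor (out first, then each input); otherwise a single linear index
+  // addresses every tensor.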
+  if (any_is_broadcasted) {
+    for (const auto& indexes :
+         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[i * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
+} // namespace internal
+
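+// Thin wrapper retained for source compatibility with existing kernels.
+// Hypothetical usage sketch (illustrative only; REALHBBF16 and SAME_AS_COMMON
+// are assumed members of SupportedTensorDtypes):
+//   apply_unitensor_elementwise_fn<float, op_name>(
+//       [](float x) { return x * x; },
+//       ctx, in, SupportedTensorDtypes::REALHBBF16,
+//       out, SupportedTensorDtypes::SAME_AS_COMMON);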
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+}
 
 /**
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
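+  // The whole body collapses to a delegation into the shared implementation,
+  // with each (tensor, dtypes) pair forwarded as a std::pair.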
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes));
 }
 
 /**
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
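+  // Same delegation as the bi-tensor case above, now with three input pairs.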
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]),
-          load_c_to_common(&data_c[c_index * c_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-      size_t c_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]),
-          load_c_to_common(&data_c[c_linear_index * c_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes),
+      std::make_pair(&c, c_dtypes));
 }
 
 inline ScalarType get_compute_type(ScalarType& common_type) {