@@ -14,6 +14,9 @@
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 
+#include <array>
+#include <utility>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
                              : s.to<int64_t>();
 }
 
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
-inline void apply_unitensor_elementwise_fn(
+namespace internal {
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
+  constexpr auto kNumInputs = sizeof...(inputs);
   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
+  const auto check_input_dtype = [](auto input, auto compute_type) {
+    return internal::check_tensor_dtype(
+        *input.first, input.second, compute_type);
+  };
   ET_KERNEL_CHECK(
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      (check_input_dtype(inputs, compute_type) && ...) &&
+          internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  bool any_is_broadcasted = false;
+  if constexpr (kNumInputs > 1) {
+    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
+  }
+
+  struct InputInfo {
+    load_to_common_fn<CTYPE_COMMON> load_to_common;
+    const char* data_ptr;
+    ssize_t element_size;
+  };
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
+          *inputs.first, inputs.second),
+      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+      inputs.first->element_size(),
+  })...};
+
   const auto store_common_to_out =
       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
           out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const auto out_element_size = out.element_size();
 
-  auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+  if (any_is_broadcasted) {
+    for (const auto& indexes :
+         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[i * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
+} // namespace internal
+
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+}
 
 /**
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
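
The public entry points keep their existing signatures, so operator code built on them is unaffected by this refactor. As a reviewer aid, here is a minimal sketch of how a unary op might call apply_unitensor_elementwise_fn; the op name, the lambda, and the dtype sets (REALHBBF16, SAME_AS_COMMON) are illustrative assumptions, not taken from this diff:

```cpp
// Hypothetical caller (illustrative only): squares each element of `a` into
// `out`, computing in float regardless of the tensors' storage dtypes.
static constexpr const char op_name[] = "my_square.out"; // assumed op name
apply_unitensor_elementwise_fn<float, op_name>(
    [](const float val_a) { return val_a * val_a; }, // compute_fun
    ctx, // KernelRuntimeContext supplied by the kernel entry point
    a, // input tensor
    SupportedTensorDtypes::REALHBBF16, // assumed input dtype set
    out, // output tensor
    SupportedTensorDtypes::SAME_AS_COMMON); // assumed output dtype set
```
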
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes));
 }
 
 /**
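
The two- and three-input wrappers now differ only in how many (tensor, dtype) pairs they forward, as the next hunk shows. Since the shared helper leans on C++17 pack expansion and fold expressions over std::pair arguments, here is a small self-contained model of that pattern (all names invented for illustration):

```cpp
#include <array>
#include <iostream>
#include <type_traits>
#include <utility>

// Each input arrives as a std::pair; a fold expression validates every pack
// element, and a pack expansion builds an array of per-input state, mirroring
// the static_assert / inputs_info construction in apply_elementwise_fn.
template <typename... Args>
int sum_firsts(Args... inputs) {
  static_assert((std::is_same_v<Args, std::pair<int, int>> && ...));
  constexpr auto kNumInputs = sizeof...(inputs);
  std::array<int, kNumInputs> firsts = {(inputs.first)...};
  int total = 0;
  for (int v : firsts) {
    total += v;
  }
  return total;
}

int main() {
  std::cout << sum_firsts(std::make_pair(1, 2), std::make_pair(3, 4)) << "\n";
  // prints 4 (1 + 3)
}
```
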
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]),
-          load_c_to_common(&data_c[c_index * c_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-      size_t c_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]),
-          load_c_to_common(&data_c[c_linear_index * c_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes),
+      std::make_pair(&c, c_dtypes));
 }
 
 inline ScalarType get_compute_type(ScalarType& common_type) {
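
One detail worth calling out in the shared loop above: each iteration gathers the loaded values into a std::array<CTYPE_COMMON, kNumInputs> and expands them into compute_fun via std::apply, which works because std::array is tuple-like. A standalone illustration (names invented):

```cpp
#include <array>
#include <iostream>
#include <tuple>

int main() {
  // Stand-in for one iteration's loaded_inputs in apply_elementwise_fn.
  std::array<float, 3> loaded_inputs = {1.0f, 2.0f, 3.0f};
  auto compute_fun = [](float a, float b, float c) { return a * b + c; };
  // std::apply unpacks the array elements as individual arguments.
  std::cout << std::apply(compute_fun, loaded_inputs) << "\n"; // prints 5
}
```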