@@ -192,7 +192,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
192192 return normalized_tensor_size;
193193}
194194
195- template <const char * op_name , typename Op>
195+ template <typename CTYPE , typename Op>
196196Tensor& handle_last_dim_broadcast_elementwise (
197197 KernelRuntimeContext& ctx,
198198 const Op& vec_fun,
@@ -221,32 +221,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
221221 " Failed to resize output tensor." );
222222 const size_t outer_size = getLeadingDims (out, out.dim () - 1 );
223223 const auto broadcast_size = out.size (out.dim () - 1 );
224- ET_SWITCH_REALB_TYPES (out_type, ctx, op_name, CTYPE, [&]() {
225- using Vec = executorch::vec::Vectorized<CTYPE>;
226- Vec alpha_val_vec;
227- if (alpha.has_value ()) {
228- CTYPE alpha_val;
229- ET_KERNEL_CHECK (
230- ctx,
231- native::utils::extract_scalar (alpha.value (), &alpha_val),
232- InvalidArgument, );
233- alpha_val_vec = Vec (alpha_val);
234- }
235- auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
236- return vec_fun (a, b, alpha_val_vec);
237- };
238- executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
239- vec_fun_alpha,
240- out.mutable_data_ptr <CTYPE>(),
241- lhs->const_data_ptr <CTYPE>(),
242- rhs->const_data_ptr <CTYPE>(),
243- outer_size,
244- broadcast_size);
245- });
224+ executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
225+ vec_fun,
226+ out.mutable_data_ptr <CTYPE>(),
227+ lhs->const_data_ptr <CTYPE>(),
228+ rhs->const_data_ptr <CTYPE>(),
229+ outer_size,
230+ broadcast_size);
246231 return out;
247232}
248233
249- template <const char * op_name , typename Op>
234+ template <typename CTYPE , typename Op>
250235Tensor& handle_broadcast_elementwise (
251236 KernelRuntimeContext& ctx,
252237 const Op& vec_fun,
@@ -259,11 +244,10 @@ Tensor& handle_broadcast_elementwise(
259244 ElementwiseOptimizedPath::kBroadcastLastDim ) ||
260245 (selected_optimized_path ==
261246 ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments )) {
262- return handle_last_dim_broadcast_elementwise<op_name >(
263- ctx, vec_fun, a, b, out, selected_optimized_path, alpha );
247+ return handle_last_dim_broadcast_elementwise<CTYPE >(
248+ ctx, vec_fun, a, b, out, selected_optimized_path);
264249 }
265250
266- ScalarType out_type = out.scalar_type ();
267251 const Tensor* lhs;
268252 const Tensor* rhs;
269253 if ((selected_optimized_path ==
@@ -306,30 +290,14 @@ Tensor& handle_broadcast_elementwise(
306290 broadcast_size = lhs->sizes ()[lhs->dim () - 2 ];
307291 inner_size = lhs->sizes ()[lhs->dim () - 1 ];
308292 }
309- ET_SWITCH_REALB_TYPES (out_type, ctx, op_name, CTYPE, [&]() {
310- using Vec = executorch::vec::Vectorized<CTYPE>;
311- Vec alpha_val_vec;
312- if (alpha.has_value ()) {
313- CTYPE alpha_val;
314- ET_KERNEL_CHECK (
315- ctx,
316- native::utils::extract_scalar (alpha.value (), &alpha_val),
317- InvalidArgument, );
318- alpha_val_vec = Vec (alpha_val);
319- }
320- auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
321- return vec_fun (a, b, alpha_val_vec);
322- };
323- executorch::vec::
324- broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype (vec_fun_alpha)>(
325- vec_fun_alpha,
326- out.mutable_data_ptr <CTYPE>(),
327- lhs->const_data_ptr <CTYPE>(),
328- rhs->const_data_ptr <CTYPE>(),
329- outer_size,
330- broadcast_size,
331- inner_size);
332- });
293+ executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
294+ vec_fun,
295+ out.mutable_data_ptr <CTYPE>(),
296+ lhs->const_data_ptr <CTYPE>(),
297+ rhs->const_data_ptr <CTYPE>(),
298+ outer_size,
299+ broadcast_size,
300+ inner_size);
333301 return out;
334302}
335303} // namespace executor
0 commit comments