@@ -191,7 +191,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
191191 return normalized_tensor_size;
192192}
193193
194- template <typename Op>
194+ template <typename CTYPE, typename Op>
195195Tensor& handle_last_dim_broadcast_elementwise (
196196 KernelRuntimeContext& ctx,
197197 const Op& vec_fun,
@@ -219,19 +219,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
219219 " Failed to resize output tensor." );
220220 const size_t outer_size = getLeadingDims (out, out.dim () - 1 );
221221 const auto broadcast_size = out.size (out.dim () - 1 );
222- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
223- executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
224- vec_fun,
225- out.mutable_data_ptr <CTYPE>(),
226- lhs->const_data_ptr <CTYPE>(),
227- rhs->const_data_ptr <CTYPE>(),
228- outer_size,
229- broadcast_size);
230- });
222+ executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
223+ vec_fun,
224+ out.mutable_data_ptr <CTYPE>(),
225+ lhs->const_data_ptr <CTYPE>(),
226+ rhs->const_data_ptr <CTYPE>(),
227+ outer_size,
228+ broadcast_size);
231229 return out;
232230}
233231
234- template <typename Op>
232+ template <typename CTYPE, typename Op>
235233Tensor& handle_broadcast_elementwise (
236234 KernelRuntimeContext& ctx,
237235 const Op& vec_fun,
@@ -243,11 +241,10 @@ Tensor& handle_broadcast_elementwise(
243241 ElementwiseOptimizedPath::kBroadcastLastDim ) ||
244242 (selected_optimized_path ==
245243 ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments )) {
246- return handle_last_dim_broadcast_elementwise (
244+ return handle_last_dim_broadcast_elementwise<CTYPE> (
247245 ctx, vec_fun, a, b, out, selected_optimized_path);
248246 }
249247
250- ScalarType out_type = out.scalar_type ();
251248 const Tensor* lhs;
252249 const Tensor* rhs;
253250 if ((selected_optimized_path ==
@@ -290,16 +287,14 @@ Tensor& handle_broadcast_elementwise(
290287 broadcast_size = lhs->sizes ()[lhs->dim () - 2 ];
291288 inner_size = lhs->sizes ()[lhs->dim () - 1 ];
292289 }
293- ET_SWITCH_REALB_TYPES (out_type, ctx, " mul.out" , CTYPE, [&]() {
294- executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
295- vec_fun,
296- out.mutable_data_ptr <CTYPE>(),
297- lhs->const_data_ptr <CTYPE>(),
298- rhs->const_data_ptr <CTYPE>(),
299- outer_size,
300- broadcast_size,
301- inner_size);
302- });
290+ executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
291+ vec_fun,
292+ out.mutable_data_ptr <CTYPE>(),
293+ lhs->const_data_ptr <CTYPE>(),
294+ rhs->const_data_ptr <CTYPE>(),
295+ outer_size,
296+ broadcast_size,
297+ inner_size);
303298 return out;
304299}
305300} // namespace executor
0 commit comments