@@ -49,38 +49,8 @@ enum class ElementwiseOptimizedPath {
4949 kBroadcastLastDimReverseArguments ,
5050};
5151
52- enum class BinaryOpType {
53- kAdd ,
54- kSub ,
55- kMul ,
56- kDiv ,
57- };
58-
5952namespace internal {
6053
61- template <BinaryOpType op_type>
62- struct BinaryOpTypeName ;
63-
64- template <>
65- struct BinaryOpTypeName <BinaryOpType::kAdd > {
66- static constexpr char kName [] = " add.out" ;
67- };
68-
69- template <>
70- struct BinaryOpTypeName <BinaryOpType::kSub > {
71- static constexpr char kName [] = " sub.out" ;
72- };
73-
74- template <>
75- struct BinaryOpTypeName <BinaryOpType::kMul > {
76- static constexpr char kName [] = " mul.out" ;
77- };
78-
79- template <>
80- struct BinaryOpTypeName <BinaryOpType::kDiv > {
81- static constexpr char kName [] = " div.out" ;
82- };
83-
8454/*
8555 Given two tensors, this function returns the broadcast dim if it exists.
8656 Returns 0 if no broadcast dim is found.
@@ -222,15 +192,15 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
222192 return normalized_tensor_size;
223193}
224194
225- template <BinaryOpType op_type , typename Op>
195+ template <const char * op_name , typename Op>
226196Tensor& handle_last_dim_broadcast_elementwise (
227197 KernelRuntimeContext& ctx,
228198 const Op& vec_fun,
229199 const Tensor& a,
230200 const Tensor& b,
231201 Tensor& out,
232202 const ElementwiseOptimizedPath selected_optimized_path,
233- executorch::aten::optional<Scalar>& alpha = {}) {
203+ const executorch::aten::optional<Scalar>& alpha = {}) {
234204 ScalarType out_type = out.scalar_type ();
235205 const Tensor* lhs;
236206 const Tensor* rhs;
@@ -251,11 +221,11 @@ Tensor& handle_last_dim_broadcast_elementwise(
251221 " Failed to resize output tensor." );
252222 const size_t outer_size = getLeadingDims (out, out.dim () - 1 );
253223 const auto broadcast_size = out.size (out.dim () - 1 );
254- ET_SWITCH_REALB_TYPES (out_type, ctx, internal::BinaryOpTypeName<op_type>:: kName , CTYPE, [&]() {
224+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
255225 using Vec = executorch::vec::Vectorized<CTYPE>;
256- CTYPE alpha_val;
257- Vec alpha_val_vec (alpha_val);
226+ Vec alpha_val_vec;
258227 if (alpha.has_value ()) {
228+ CTYPE alpha_val;
259229 ET_KERNEL_CHECK (
260230 ctx,
261231 native::utils::extract_scalar (alpha.value (), &alpha_val),
@@ -276,20 +246,20 @@ Tensor& handle_last_dim_broadcast_elementwise(
276246 return out;
277247}
278248
279- template <BinaryOpType op_type , typename Op>
249+ template <const char * op_name , typename Op>
280250Tensor& handle_broadcast_elementwise (
281251 KernelRuntimeContext& ctx,
282252 const Op& vec_fun,
283253 const Tensor& a,
284254 const Tensor& b,
285255 Tensor& out,
286256 const ElementwiseOptimizedPath selected_optimized_path,
287- executorch::aten::optional<Scalar> alpha = {}) {
257+ const executorch::aten::optional<Scalar>& alpha = {}) {
288258 if ((selected_optimized_path ==
289259 ElementwiseOptimizedPath::kBroadcastLastDim ) ||
290260 (selected_optimized_path ==
291261 ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments )) {
292- return handle_last_dim_broadcast_elementwise<op_type >(
262+ return handle_last_dim_broadcast_elementwise<op_name >(
293263 ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
294264 }
295265
@@ -336,11 +306,11 @@ Tensor& handle_broadcast_elementwise(
336306 broadcast_size = lhs->sizes ()[lhs->dim () - 2 ];
337307 inner_size = lhs->sizes ()[lhs->dim () - 1 ];
338308 }
339- ET_SWITCH_REALB_TYPES (out_type, ctx, internal::BinaryOpTypeName<op_type>:: kName , CTYPE, [&]() {
309+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
340310 using Vec = executorch::vec::Vectorized<CTYPE>;
341- CTYPE alpha_val;
342311 Vec alpha_val_vec;
343312 if (alpha.has_value ()) {
313+ CTYPE alpha_val;
344314 ET_KERNEL_CHECK (
345315 ctx,
346316 native::utils::extract_scalar (alpha.value (), &alpha_val),
0 commit comments