@@ -68,114 +68,6 @@ template <
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
 
-Tensor& handle_last_dim_broadcast(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if (selected_optimized_path ==
-      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
-  const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
-  return out;
-}
-
-Tensor& handle_broadcast_mul(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out,
-    const ElementwiseOptimizedPath selected_optimized_path) {
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDim) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
-  }
-
-  ScalarType out_type = out.scalar_type();
-  const Tensor* lhs;
-  const Tensor* rhs;
-  if ((selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    lhs = &b;
-    rhs = &a;
-  } else {
-    // Catch failure to update logic when adding new broadcasting possibility.
-    ET_DCHECK(
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd));
-    lhs = &a;
-    rhs = &b;
-  }
-  auto error = resize_tensor(out, lhs->sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-  int64_t outer_size = 1;
-  int64_t broadcast_size;
-  int64_t inner_size;
-  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-      (selected_optimized_path ==
-       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-    auto normalized_tensor_size_lhs =
-        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-    outer_size = normalized_tensor_size_lhs[0];
-    broadcast_size = normalized_tensor_size_lhs[1];
-    inner_size = normalized_tensor_size_lhs[2];
-  } else {
-    broadcast_size = lhs->sizes()[lhs->dim() - 2];
-    inner_size = lhs->sizes()[lhs->dim() - 1];
-  }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-        [](Vec x, Vec y) { return x * y; },
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size,
-        inner_size);
-  });
-  return out;
-}
-
 } // namespace
 
 Tensor& opt_mul_out(
@@ -238,7 +130,9 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
+    auto mul_lambda = [](auto x, auto y) { return x * y; };
+    return torch::executor::handle_broadcast_elementwise(
+        ctx, mul_lambda, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
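
Note on the change above: the mul-specific broadcast helpers are deleted in favor of a single generic torch::executor::handle_broadcast_elementwise that receives the elementwise op as a lambda, so each optimized broadcast path is written once and shared across operators. Below is a minimal, framework-free sketch of that pattern in plain C++. It is not the ExecuTorch implementation: the helper name broadcast_2d_by_1d_elementwise, its signature, and the row-wise broadcast shape are illustrative assumptions standing in for the real helper, which additionally picks the broadcast geometry from selected_optimized_path and dispatches over dtypes.

// Sketch only: a broadcast loop parameterized on the elementwise op, so the
// same helper serves mul, add, etc. Names are illustrative, not ExecuTorch APIs.
#include <cstddef>
#include <iostream>
#include <vector>

// Applies `op` to an (outer_size x broadcast_size) `lhs` and a 1-D `rhs` of
// length broadcast_size, broadcasting `rhs` across every row of `lhs`.
template <typename T, typename Op>
void broadcast_2d_by_1d_elementwise(
    Op op,
    T* out,
    const T* lhs,
    const T* rhs,
    std::size_t outer_size,
    std::size_t broadcast_size) {
  for (std::size_t i = 0; i < outer_size; ++i) {
    for (std::size_t j = 0; j < broadcast_size; ++j) {
      out[i * broadcast_size + j] = op(lhs[i * broadcast_size + j], rhs[j]);
    }
  }
}

int main() {
  std::vector<float> a = {1, 2, 3, 4, 5, 6}; // shape (2, 3)
  std::vector<float> b = {10, 20, 30};       // shape (3,), broadcast over rows
  std::vector<float> out(a.size());

  // Only the lambda is operator-specific; the broadcast loop is shared,
  // which is why the per-operator handle_broadcast_mul could be removed.
  auto mul_lambda = [](float x, float y) { return x * y; };
  broadcast_2d_by_1d_elementwise(
      mul_lambda, out.data(), a.data(), b.data(), 2, 3);

  for (float v : out) {
    std::cout << v << ' '; // prints: 10 40 90 40 100 180
  }
  std::cout << '\n';
  return 0;
}

Swapping the lambda (for example, x + y) reuses the same loop for other elementwise operators, which is the point of the refactor shown in the diff.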