@@ -115,45 +115,65 @@ Tensor& opt_add_sub_out_impl(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Cannot apply the trick of -alpha here because alpha is Scalar without
     // support for - operator. At least not right now.
-    if constexpr (is_sub) {
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return y - alpha_val * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return x - alpha_val * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    } else {
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return y + alpha_val * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument,
+          out,
+          "Failed to extract scalar alpha.");
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      Vec alpha_val_vec(alpha_val);
+      if constexpr (is_sub) {
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y - alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x - alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
       } else {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return x + alpha_val * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          // The reason we swap args here is that
+          // handle_broadcast_elementwise handles this selected_optimized_path
+          // option a bit differently. This should really be resolved in
+          // handle_broadcast_elementwise. However, the current blocker is that
+          // handle_broadcast_elementwise tries to be agnostic of the op. This
+          // should be fixed, likely by moving lambda creation into
+          // handle_broadcast_elementwise and making it aware of which op is
+          // being executed.
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y + alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x + alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
       }
-    }
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
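
The argument swap called out in the comment on the *ReverseArguments branches is easier to see outside the kernel. Below is a minimal, self-contained sketch in plain C++, not the ExecuTorch code: apply_elementwise_broadcast_1d is a hypothetical stand-in for handle_broadcast_elementwise, and a plain float alpha stands in for Vectorized<CTYPE>. The stand-in always feeds the larger operand as the lambda's first argument, so when b is the larger tensor ("reverse arguments") the lambda must compute y - alpha * x to still produce a - alpha * b.

// Sketch only: illustrates why the lambda swaps x and y on reverse-argument
// broadcast paths. apply_elementwise_broadcast_1d is a hypothetical helper,
// not part of ExecuTorch.
#include <cstddef>
#include <cstdio>
#include <vector>

// Applies fn(big[i], small_vec[i % small_vec.size()]); the larger operand is
// always the first lambda argument, mirroring how the optimized broadcast
// path drives its inner loop.
template <typename Fn>
std::vector<float> apply_elementwise_broadcast_1d(
    const std::vector<float>& big,
    const std::vector<float>& small_vec,
    Fn fn) {
  std::vector<float> out(big.size());
  for (std::size_t i = 0; i < big.size(); ++i) {
    out[i] = fn(big[i], small_vec[i % small_vec.size()]);
  }
  return out;
}

int main() {
  const float alpha = 2.0f;

  // Normal path: a is the larger operand, so the lambda sees (x = a, y = b).
  std::vector<float> a = {10.f, 20.f, 30.f, 40.f};
  std::vector<float> b = {1.f, 2.f};
  auto sub_lambda = [alpha](float x, float y) { return x - alpha * y; };
  auto out1 = apply_elementwise_broadcast_1d(a, b, sub_lambda);

  // Reverse-arguments path: b is the larger operand, so the helper feeds
  // (x = b, y = a). To still compute a - alpha * b, swap x and y in the lambda.
  std::vector<float> a2 = {1.f, 2.f};
  std::vector<float> b2 = {10.f, 20.f, 30.f, 40.f};
  auto sub_lambda_rev = [alpha](float x, float y) { return y - alpha * x; };
  auto out2 = apply_elementwise_broadcast_1d(b2, a2, sub_lambda_rev);

  for (float v : out1) { std::printf("%g ", v); }  // prints: 8 16 28 36
  std::printf("\n");
  for (float v : out2) { std::printf("%g ", v); }  // prints: -19 -38 -59 -78
  std::printf("\n");
  return 0;
}

Under these assumptions, the same reasoning carries over to the add case (y + alpha * x on reverse paths), which is why the diff keeps one lambda per branch rather than resolving the swap inside the helper.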