@@ -120,48 +120,36 @@ Tensor& opt_div_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
+    // The lambdas below take an unused alpha parameter because
+    // handle_broadcast_elementwise is also used for add and sub:
     if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
+            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+      // This behavior is a bit confusing.
+      // We swap the operands inside the lambda here because
+      // handle_broadcast_elementwise treats the ReverseArguments paths a bit
+      // differently. This should really be resolved in
+      // handle_broadcast_elementwise, but the current blocker is that it
+      // tries to stay agnostic of the op. That should be fixed, likely by
+      // moving lambda creation into handle_broadcast_elementwise and making
+      // it aware of which op is being executed.
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha;
+        return y / x;
+      };
+      return torch::executor::handle_broadcast_elementwise(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha;
+        return x / y;
+      };
+      return torch::executor::handle_broadcast_elementwise(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return y / x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return x / y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      }
-    });
   } else {
     ScalarType common_type = get_compute_type(a_type, b_type);
     ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
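
To make the operand swap in the hunk above concrete: the new comments note that handle_broadcast_elementwise treats the ReverseArguments paths differently, and that every lambda takes an alpha so add and sub can share the same dispatcher. Below is a minimal, self-contained sketch of that contract; `broadcast_last_dim` and its signature are hypothetical stand-ins, not the ExecuTorch API, and it assumes the dispatcher always passes the larger tensor's element as `x` and the broadcast tensor's element as `y`:

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for handle_broadcast_elementwise: applies
// op(x, y, alpha) elementwise, broadcasting `small` (shape [cols]) across
// each row of `big` (shape [rows, cols], flattened row-major).
template <typename Op>
void broadcast_last_dim(
    Op op,
    std::vector<float>& out,
    const std::vector<float>& big,
    const std::vector<float>& small,
    size_t rows,
    size_t cols,
    float alpha) {
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      // The dispatcher always feeds the larger tensor's element as x and the
      // broadcast element as y, regardless of argument order at the call site.
      out[r * cols + c] = op(big[r * cols + c], small[c], alpha);
    }
  }
}

int main() {
  std::vector<float> a = {2, 4, 6, 8}; // shape [2, 2]
  std::vector<float> b = {2, 2};       // shape [2], broadcast over the last dim
  std::vector<float> out(4);

  // Plain path (a / b): x comes from a, y from b, so the lambda is x / y.
  auto div_lambda = [](auto x, auto y, auto alpha) {
    (void)alpha; // unused by div; the slot exists so add/sub can share it
    return x / y;
  };
  broadcast_last_dim(div_lambda, out, a, b, 2, 2, /*alpha=*/1.0f);
  std::printf("a / b = %g %g %g %g\n", out[0], out[1], out[2], out[3]); // 1 2 3 4

  // ReverseArguments path (b / a): the dispatcher still iterates with the
  // larger tensor as x, so the lambda must flip the operands to y / x.
  auto div_lambda_rev = [](auto x, auto y, auto alpha) {
    (void)alpha;
    return y / x;
  };
  broadcast_last_dim(div_lambda_rev, out, a, b, 2, 2, /*alpha=*/1.0f);
  std::printf("b / a = %g %g %g %g\n", out[0], out[1], out[2], out[3]); // 1 0.5 0.333 0.25
  return 0;
}
```

Under that assumption, flipping the operands inside the lambda (`y / x`) reproduces the old `lhs = &b; rhs = &a;` swap while letting the shared dispatcher own the iteration and the output resize that the deleted block did by hand. An add lambda would actually consume the third argument, e.g. `x + alpha * y`, which is why div keeps the parameter and silences it with `(void)alpha`.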