@@ -120,48 +120,37 @@ Tensor& opt_div_out(
         out.numel());
   });
 } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
+    // The lambdas below take an alpha argument because
+    // handle_broadcast_elementwise is also used for add and sub:
+    static constexpr const char op_name[] = "div.out";
     if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
+            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+        selected_optimized_path ==
+            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+      // This behavior is a bit confusing.
+      // We swap the operands inside the lambda here because
+      // handle_broadcast_elementwise handles this selected_optimized_path
+      // option a bit differently. This should really be resolved in
+      // handle_broadcast_elementwise; the current blocker is that it tries
+      // to be agnostic of the op. The likely fix is to move lambda creation
+      // into handle_broadcast_elementwise and make it aware of which op is
+      // being executed.
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha; // alpha is unused for div
+        return y / x;
+      };
+      return torch::executor::handle_broadcast_elementwise<op_name>(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
+      auto div_lambda = [](auto x, auto y, auto alpha) {
+        (void)alpha; // alpha is unused for div
+        return x / y;
+      };
+      return torch::executor::handle_broadcast_elementwise<op_name>(
+          ctx, div_lambda, a, b, out, selected_optimized_path);
     }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return y / x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return x / y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      }
-    });
 } else {
   ScalarType common_type = get_compute_type(a_type, b_type);
   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
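
To make the operand swap above concrete, here is a minimal, self-contained sketch. This is not the real handle_broadcast_elementwise API (broadcast_apply is a hypothetical stand-in, and the unused alpha argument is omitted): it models a helper that always passes the broadcast (smaller) operand as the lambda's first argument, so a non-commutative op like division must encode the original a / b order in the lambda body.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for handle_broadcast_elementwise: broadcasts the
// single-element `small` tensor against `large`, always calling
// fn(small_element, large_element) regardless of which operand the caller
// passed first.
template <typename Op>
std::vector<float> broadcast_apply(
    const std::vector<float>& small, const std::vector<float>& large, Op fn) {
  std::vector<float> out(large.size());
  for (std::size_t i = 0; i < large.size(); ++i) {
    out[i] = fn(small[0], large[i]);
  }
  return out;
}

int main() {
  std::vector<float> a = {8.f, 4.f, 2.f}; // the larger operand
  std::vector<float> b = {2.f};           // broadcast against a

  // ReverseArguments case: the helper hands b in as x and a in as y, so
  // computing a / b requires the swapped body y / x.
  auto div_lambda = [](float x, float y) { return y / x; };
  for (float v : broadcast_apply(b, a, div_lambda)) {
    std::printf("%g ", v); // prints: 4 2 1
  }
  std::printf("\n");
  return 0;
}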
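The refactor the comment proposes, moving lambda creation into handle_broadcast_elementwise, could look roughly like the hedged sketch below. The names (ElementwiseOp, make_kernel) are illustrative, not ExecuTorch API; the point is that an op-aware helper can fold the ReverseArguments compensation into one place instead of every call site.

#include <functional>

// Hypothetical op tag; not part of ExecuTorch.
enum class ElementwiseOp { kAdd, kSub, kMul, kDiv };

// Builds the scalar kernel inside the helper, compensating for reversed
// argument order so no caller has to hand-swap operands.
inline std::function<float(float, float)> make_kernel(
    ElementwiseOp op, bool reversed) {
  switch (op) {
    case ElementwiseOp::kDiv:
      if (reversed) {
        return [](float x, float y) { return y / x; };
      }
      return [](float x, float y) { return x / y; };
    case ElementwiseOp::kSub:
      if (reversed) {
        return [](float x, float y) { return y - x; };
      }
      return [](float x, float y) { return x - y; };
    case ElementwiseOp::kAdd:
      return [](float x, float y) { return x + y; }; // commutative
    case ElementwiseOp::kMul:
    default:
      return [](float x, float y) { return x * y; }; // commutative
  }
}

With a helper shaped like this, opt_div_out would only pass an op tag, and the swapped y / x body would no longer leak into individual kernels.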