99#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010#include <executorch/kernels/portable/cpu/scalar_utils.h>
1111#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12+ #include <executorch/kernels/portable/cpu/util/dtype_util.h>
13+ #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1214#include <executorch/kernels/portable/cpu/util/functional_util.h>
1315#include <executorch/kernels/portable/cpu/util/math_util.h>
1416#include <executorch/runtime/kernel/kernel_includes.h>
@@ -134,25 +136,26 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
134136 InvalidArgument,
135137 out);
136138
137- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.out" , CTYPE_A, [&]() {
138- ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " div.out" , CTYPE_B, [&]() {
139- ET_SWITCH_FLOAT_TYPES (common_type, ctx, " div.out" , CTYPE_IN, [&]() {
140- ET_SWITCH_FLOAT_TYPES (out_type, ctx, " div.out" , CTYPE_OUT, [&]() {
141- torch::executor::
142- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
143- [](const CTYPE_A val_a, const CTYPE_B val_b) {
144- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
145- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
146- CTYPE_IN value = a_casted / b_casted;
147-
148- return static_cast <CTYPE_OUT>(value);
149- },
150- a,
151- b,
152- out);
153- });
154- });
155- });
139+ // Compute Dtype
140+ ScalarType compute_type =
141+ torch::executor::native::utils::get_compute_type (common_type);
142+
143+ // @lint-ignore CLANGTIDY facebook-hte-CArray
144+ static constexpr const char op_name[] = " div.out" ;
145+
146+ ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
147+ torch::executor::native::utils::
148+ apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
149+ [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
150+ return val_a / val_b;
151+ },
152+ ctx,
153+ a,
154+ torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
155+ b,
156+ torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
157+ out,
158+ torch::executor::native::utils::SupportedTensorDtypes::FLOATHBF16);
156159 });
157160
158161 return out;
@@ -254,35 +257,59 @@ Tensor& div_out_mode(
254257 return out;
255258 }
256259
257- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.out_mode" , CTYPE_A, [&]() {
258- ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " div.out_mode" , CTYPE_B, [&]() {
259- ET_SWITCH_FLOAT_TYPES (common_type, ctx, " div.out_mode" , CTYPE_IN, [&]() {
260- ET_SWITCH_REAL_TYPES (out_type, ctx, " div.out_mode" , CTYPE_OUT, [&]() {
261- torch::executor::
262- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
263- [mode](const CTYPE_A val_a, const CTYPE_B val_b) {
264- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
265- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
266- CTYPE_IN value = a_casted / b_casted;
267- if (mode.has_value () && mode.value () == " trunc" ) {
268- value = std::trunc (value);
269- } else if (mode.has_value () && mode.value () == " floor" ) {
270- value = std::floor (value);
271- }
272- return static_cast <CTYPE_OUT>(value);
273- },
274- a,
275- b,
276- out);
277- });
278- });
279- });
260+ bool div_by_zero_error = false ;
261+ const bool mode_is_trunc = (mode.has_value () && mode.value () == " trunc" );
262+ // Compute Dtype
263+ ScalarType compute_type =
264+ torch::executor::native::utils::get_compute_type (common_type);
265+
266+ // @lint-ignore CLANGTIDY facebook-hte-CArray
267+ static constexpr const char op_name[] = " div.out_mode" ;
268+
269+ ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
270+ torch::executor::native::utils::
271+ apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
272+ [mode_is_trunc, &div_by_zero_error](
273+ const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
274+ if (executorch::runtime::is_integral_type<
275+ CTYPE_COMPUTE,
276+ /* includeBool=*/ true >::value) {
277+ if (val_b == 0 ) {
278+ div_by_zero_error = true ;
279+ return static_cast <CTYPE_COMPUTE>(0 );
280+ }
281+ }
282+ CTYPE_COMPUTE value = val_a / val_b;
283+ if (mode_is_trunc) {
284+ value = std::trunc (value);
285+ } else {
286+ // We established above that the mode is either trunc or floor,
287+ // so it must be floor.
288+ value =
289+ torch::executor::native::utils::floor_divide (val_a, val_b);
290+ }
291+ return value;
292+ },
293+ ctx,
294+ a,
295+ torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
296+ b,
297+ torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
298+ out,
299+ torch::executor::native::utils::SupportedTensorDtypes::REALHBF16);
280300 });
281301
302+ ET_KERNEL_CHECK_MSG (
303+ ctx,
304+ !div_by_zero_error,
305+ InvalidArgument,
306+ out,
307+ " Div mode operation encountered integer division by zero" );
308+
282309 return out;
283310}
284311
285312} // namespace native
286313} // namespace HiFi
287314} // namespace impl
288- } // namespace cadence
315+ } // namespace cadence
0 commit comments