@@ -58,17 +58,17 @@ Tensor& div_out(
5858 static constexpr const char op_name[] = " div.out" ;
5959
6060 ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61- utils::apply_bitensor_elementwise_fn<
62- CTYPE_COMPUTE,
63- op_name,
64- utils::SupportedTensorDtypes::FLOATHBF16>(
65- [](const auto val_a, const auto val_b) { return val_a / val_b; },
61+ utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62+ [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63+ return val_a / val_b;
64+ },
6665 ctx,
6766 a,
6867 utils::SupportedTensorDtypes::REALHBBF16,
6968 b,
7069 utils::SupportedTensorDtypes::REALHBBF16,
71- out);
70+ out,
71+ utils::SupportedTensorDtypes::FLOATHBF16);
7272 });
7373
7474 return out;
@@ -122,13 +122,9 @@ Tensor& div_out_mode(
122122 bool div_by_zero_error = false ;
123123
124124 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
125- utils::apply_bitensor_elementwise_fn<
126- CTYPE_COMPUTE,
127- op_name,
128- utils::SupportedTensorDtypes::REALHBF16>(
125+ utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
129126 [mode_is_trunc, &div_by_zero_error](
130127 const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
131- // TODO: rewrite this to be vectorization-capable.
132128 if (is_integral_type<CTYPE_COMPUTE, /* includeBool=*/ true >::value) {
133129 if (val_b == 0 ) {
134130 div_by_zero_error = true ;
@@ -150,7 +146,8 @@ Tensor& div_out_mode(
150146 utils::SupportedTensorDtypes::REALHBBF16,
151147 b,
152148 utils::SupportedTensorDtypes::REALHBBF16,
153- out);
149+ out,
150+ utils::SupportedTensorDtypes::REALHBF16);
154151 });
155152
156153 ET_KERNEL_CHECK_MSG (
@@ -191,15 +188,13 @@ Tensor& div_scalar_out(
191188
192189 ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
193190 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
194- utils::apply_unitensor_elementwise_fn<
195- CTYPE_COMPUTE,
196- op_name,
197- utils::SupportedTensorDtypes::SAME_AS_COMMON>(
198- [val_b](const auto val_a) { return val_a / val_b; },
191+ utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
192+ [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
199193 ctx,
200194 a,
201195 utils::SupportedTensorDtypes::REALHBBF16,
202- out);
196+ out,
197+ utils::SupportedTensorDtypes::SAME_AS_COMMON);
203198 });
204199
205200 return out;
0 commit comments