Skip to content

Commit cdca6ca

Browse files
For div operator and for generic case, added get_compute type function to support integer inputs and float output
1 parent d8cf517 commit cdca6ca

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

backends/cadence/fusion_g3/operators/op_div.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ Tensor& div_out(KernelRuntimeContext& ctx,
220220
}
221221
else
222222
{
223+
ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
224+
ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type);
223225

224226
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
225227
torch::executor::native::utils::
@@ -540,6 +542,8 @@ Tensor& div_scalar_out(KernelRuntimeContext& ctx,
540542
}
541543
else
542544
{
545+
ScalarType common_type = executorch::runtime::isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
546+
ScalarType compute_type = torch::executor::native::utils::get_compute_type(common_type);
543547
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
544548
const CTYPE_COMPUTE val_b = torch::executor::native::utils::
545549
scalar_to<CTYPE_COMPUTE>(b);

0 commit comments

Comments
 (0)