diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt
index abdeeb73453..babcf205812 100644
--- a/kernels/optimized/CMakeLists.txt
+++ b/kernels/optimized/CMakeLists.txt
@@ -61,7 +61,7 @@ message("Generated files ${gen_command_sources}")
 list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(optimized_kernels ${_optimized_kernels__srcs})
 target_link_libraries(
-  optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
+  optimized_kernels PRIVATE executorch_core portable_kernels cpublas extension_threadpool
 )
 target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})
 # Build a library for _optimized_kernels_srcs
diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp
index 4d7b8efe9e3..c071e63f0dc 100644
--- a/kernels/optimized/cpu/op_div.cpp
+++ b/kernels/optimized/cpu/op_div.cpp
@@ -9,35 +9,13 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/kernels/portable/cpu/op_div_impl.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
-  ET_CHECK(
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
-  ET_CHECK(
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
-
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& opt_div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
@@ -163,34 +141,7 @@ Tensor& opt_div_out(
       }
     });
   } else {
-    ScalarType common_type = get_compute_type(a_type, b_type);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
-          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value = a_casted / b_casted;
-
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a,
-                b,
-                out);
-          });
-        });
-      });
-    });
+    div_out_impl(ctx, a, b, out);
   }
 
   return out;
@@ -232,32 +183,7 @@ Tensor& opt_div_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES(
-                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES(
-                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
-                            CTYPE_B b_val;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
-
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) *
-                                  inv_b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    div_scalar_out_impl(ctx, a, b, out);
   }
 
   return out;
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 77a270cc45d..dc186d70a8b 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -20,8 +20,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_div",
         deps = [
             ":binary_ops",
-            "//executorch/kernels/portable/cpu:scalar_utils",
-            "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu:op_div_impl",
         ],
     ),
     op_target(name = "op_exp"),
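
Note: div_out_impl and div_scalar_out_impl are the shared slow-path entry
points provided by the //executorch/kernels/portable/cpu:op_div_impl target.
The header itself is not part of this diff; a minimal sketch of the interface,
inferred from the call sites above rather than copied from the real file,
would be:

    // op_div_impl.h -- interface sketch only; names and return types are
    // inferred from the call sites in this diff, not from the actual header.
    #pragma once

    #include <executorch/runtime/kernel/kernel_includes.h>

    namespace torch {
    namespace executor {
    namespace native {

    // Type-promoting, broadcasting division; the fallback for div.out.
    Tensor& div_out_impl(
        KernelRuntimeContext& ctx,
        const Tensor& a,
        const Tensor& b,
        Tensor& out);

    // Scalar-divisor variant; the fallback for div.Scalar_out.
    Tensor& div_scalar_out_impl(
        KernelRuntimeContext& ctx,
        const Tensor& a,
        const Scalar& b,
        Tensor& out);

    } // namespace native
    } // namespace executor
    } // namespace torch

With this split, the optimized kernels keep their vectorized fast paths and
delegate only the mixed-dtype/broadcast fallback to the portable
implementation, which is also why optimized_kernels now links against
portable_kernels in CMakeLists.txt.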