diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 0a2d4bfc89a..6adf81f70e3 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -319,22 +318,20 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - if constexpr (should_include_kernel_dtype(op_name, compute_type)) { - const bool all_inputs_compute_dtype = - ((inputs.first->scalar_type() == compute_type) && ...); - - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); - if (all_inputs_compute_dtype && - out.scalar_type() == out_specialized_scalar_type) { - using CTYPE_OUT = - typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl< - CTYPE_COMPUTE, - CTYPE_OUT, - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); - return; - } + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); + return; } apply_elementwise_fn_generic_impl< diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 1806ebb0d5a..d158ab136ab 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -115,11 +115,11 @@ def define_common_targets(): ":vectorized_math", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", "//executorch/kernels/portable/cpu:scalar_utils", ], deps = [ + "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],