diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 948da50fdd4..a376a89747b 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -345,20 +346,22 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - const bool all_inputs_compute_dtype = - ((inputs.first->scalar_type() == compute_type) && ...); - - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); - if (all_inputs_compute_dtype && - out.scalar_type() == out_specialized_scalar_type) { - using CTYPE_OUT = - typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl< - CTYPE_COMPUTE, - CTYPE_OUT, - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); - return; + if constexpr (should_include_kernel_dtype(op_name, compute_type)) { + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); + return; + } } apply_elementwise_fn_generic_impl< diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 65a0c9fc47a..8c6633720e7 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -114,10 +114,10 @@ def define_common_targets(): ":vectorized_math", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", ], deps = [ - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],