From 65ff10b433a2a82e8aa93c7cb51066f45f83a86d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 25 Jun 2025 10:30:59 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/elementwise_util.h | 31 +++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 948da50fdd4..a376a89747b 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -345,20 +346,22 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - const bool all_inputs_compute_dtype = - ((inputs.first->scalar_type() == compute_type) && ...); - - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); - if (all_inputs_compute_dtype && - out.scalar_type() == out_specialized_scalar_type) { - using CTYPE_OUT = - typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl< - CTYPE_COMPUTE, - CTYPE_OUT, - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); - return; + if constexpr (should_include_kernel_dtype(op_name, compute_type)) { + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); + return; + } } apply_elementwise_fn_generic_impl< From e7fce3dc29257e3e335bea4d4580fceb2d8904b4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 25 Jun 2025 10:55:57 -0700 Subject: [PATCH 2/2] fix buck dep [ghstack-poisoned] --- kernels/portable/cpu/util/targets.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 65a0c9fc47a..8c6633720e7 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -114,10 +114,10 @@ def define_common_targets(): ":vectorized_math", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", ], deps = [ - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],