From 71d32cc6c3cc371d3c2596c96b61cb8136824e48 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 2 Jul 2025 16:36:31 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/elementwise_util.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 0a2d4bfc89a..f195afd60dc 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -319,12 +319,13 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - if constexpr (should_include_kernel_dtype(op_name, compute_type)) { + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if constexpr (should_include_kernel_dtype( + op_name, out_specialized_scalar_type)) { const bool all_inputs_compute_dtype = ((inputs.first->scalar_type() == compute_type) && ...); - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); if (all_inputs_compute_dtype && out.scalar_type() == out_specialized_scalar_type) { using CTYPE_OUT = From dc414617ce956aafbb69b047c4872e70b3c0382c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 3 Jul 2025 15:00:43 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/elementwise_util.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index f195afd60dc..5bb5becf185 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -193,9 +193,8 @@ inline void dtype_specialized_elementwise_fn_impl( }); } -template +template inline bool validate_elementwise_fn_inputs( - const Op& compute_fun, KernelRuntimeContext& ctx, const Tensor& out, SupportedTensorDtypes out_dtypes, @@ -288,7 +287,7 @@ inline void apply_elementwise_fn_runtime_out_dtypes( SupportedTensorDtypes out_dtypes, Args... inputs) { const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + ctx, out, out_dtypes, inputs...); if (!inputs_valid) { return; } @@ -313,7 +312,7 @@ inline void apply_elementwise_fn( const Tensor& out, Args... inputs) { const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + ctx, out, out_dtypes, inputs...); if (!inputs_valid) { return; }