diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 0a2d4bfc89a..5bb5becf185 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -193,9 +193,8 @@ inline void dtype_specialized_elementwise_fn_impl( }); } -template +template inline bool validate_elementwise_fn_inputs( - const Op& compute_fun, KernelRuntimeContext& ctx, const Tensor& out, SupportedTensorDtypes out_dtypes, @@ -288,7 +287,7 @@ inline void apply_elementwise_fn_runtime_out_dtypes( SupportedTensorDtypes out_dtypes, Args... inputs) { const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + ctx, out, out_dtypes, inputs...); if (!inputs_valid) { return; } @@ -313,18 +312,19 @@ inline void apply_elementwise_fn( const Tensor& out, Args... inputs) { const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + ctx, out, out_dtypes, inputs...); if (!inputs_valid) { return; } constexpr auto compute_type = CppTypeToScalarType::value; - if constexpr (should_include_kernel_dtype(op_name, compute_type)) { + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if constexpr (should_include_kernel_dtype( + op_name, out_specialized_scalar_type)) { const bool all_inputs_compute_dtype = ((inputs.first->scalar_type() == compute_type) && ...); - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); if (all_inputs_compute_dtype && out.scalar_type() == out_specialized_scalar_type) { using CTYPE_OUT =