Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,8 @@ inline void dtype_specialized_elementwise_fn_impl(
});
}

template <typename CTYPE_COMPUTE, typename Op, typename... Args>
template <typename CTYPE_COMPUTE, typename... Args>
inline bool validate_elementwise_fn_inputs(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
SupportedTensorDtypes out_dtypes,
Expand Down Expand Up @@ -288,7 +287,7 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
SupportedTensorDtypes out_dtypes,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}
Expand All @@ -313,18 +312,19 @@ inline void apply_elementwise_fn(
const Tensor& out,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}

constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
if constexpr (should_include_kernel_dtype(op_name, compute_type)) {
constexpr ScalarType out_specialized_scalar_type =
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
if constexpr (should_include_kernel_dtype(
op_name, out_specialized_scalar_type)) {
const bool all_inputs_compute_dtype =
((inputs.first->scalar_type() == compute_type) && ...);

constexpr ScalarType out_specialized_scalar_type =
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
if (all_inputs_compute_dtype &&
out.scalar_type() == out_specialized_scalar_type) {
using CTYPE_OUT =
Expand Down
Loading