Skip to content

Commit b2c78c7

Browse files
committed
Respect selective build for dtype_specialized_elementwise_fn_impl in elementwise_util
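The dtype-specialized fast path I added to apply_elementwise_fn did not respect selective build. It is now compiled in only when should_include_kernel_dtype(op_name, compute_type) is true; otherwise execution falls through to the generic implementation.

ghstack-source-id: 69988ef
ghstack-comment-id: 3005605915
Pull-Request-resolved: #11975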
1 parent 3d437c3 commit b2c78c7

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <c10/util/irange.h>
12+
#include <executorch/kernels/portable/cpu/selective_build.h>
1213
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
1314
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1415
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
@@ -345,20 +346,22 @@ inline void apply_elementwise_fn(
345346
}
346347

347348
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
348-
const bool all_inputs_compute_dtype =
349-
((inputs.first->scalar_type() == compute_type) && ...);
350-
351-
constexpr ScalarType out_specialized_scalar_type =
352-
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
353-
if (all_inputs_compute_dtype &&
354-
out.scalar_type() == out_specialized_scalar_type) {
355-
using CTYPE_OUT =
356-
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
357-
dtype_specialized_elementwise_fn_impl<
358-
CTYPE_COMPUTE,
359-
CTYPE_OUT,
360-
support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
361-
return;
349+
if constexpr (should_include_kernel_dtype(op_name, compute_type)) {
350+
const bool all_inputs_compute_dtype =
351+
((inputs.first->scalar_type() == compute_type) && ...);
352+
353+
constexpr ScalarType out_specialized_scalar_type =
354+
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
355+
if (all_inputs_compute_dtype &&
356+
out.scalar_type() == out_specialized_scalar_type) {
357+
using CTYPE_OUT =
358+
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
359+
dtype_specialized_elementwise_fn_impl<
360+
CTYPE_COMPUTE,
361+
CTYPE_OUT,
362+
support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
363+
return;
364+
}
362365
}
363366

364367
apply_elementwise_fn_generic_impl<

0 commit comments

Comments
 (0)