diff --git a/kernels/portable/cpu/op_logical_not.cpp b/kernels/portable/cpu/op_logical_not.cpp index cf10f572e43..a67edf5cad5 100644 --- a/kernels/portable/cpu/op_logical_not.cpp +++ b/kernels/portable/cpu/op_logical_not.cpp @@ -33,10 +33,10 @@ logical_not_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK(ctx, tensors_have_same_shape(in, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] { - ET_SWITCH_REAL_TYPES_AND( - Bool, out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES( + in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES( + out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] { apply_unary_map_fn( [](const CTYPE_IN val_in) { return static_cast(!static_cast(val_in)); diff --git a/kernels/test/op_logical_not_test.cpp b/kernels/test/op_logical_not_test.cpp index c8173598605..1b23b93b03c 100644 --- a/kernels/test/op_logical_not_test.cpp +++ b/kernels/test/op_logical_not_test.cpp @@ -122,9 +122,9 @@ TEST_F(OpLogicalNotOutTest, AllTypePasses) { test_logical_not_out(); #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ - ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + ET_FORALL_REALHBBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY #undef TEST_KERNEL } diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 668c3c1cacb..31037843e50 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -252,6 +252,10 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16) +#define ET_FORALL_REALHBBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT2, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, bool, Bool) + // For macros that take `SCALARTYPEn` parameters, those parameters should be // an unquoted/unqualified enumerator name like `Int` or `Float`. #define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _) \