diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp
index d240b9f83bc..525199a6f78 100644
--- a/kernels/portable/cpu/util/dtype_util.cpp
+++ b/kernels/portable/cpu/util/dtype_util.cpp
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::INTB:
       return executorch::runtime::tensor_is_integral_type(t, true);
+    case SupportedTensorDtypes::BOOL:
+      return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (executorch::runtime::tensor_is_type(
           t, ScalarType::Bool, ScalarType::Byte));
diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h
index 1e7901c80b2..15732219c8f 100644
--- a/kernels/portable/cpu/util/dtype_util.h
+++ b/kernels/portable/cpu/util/dtype_util.h
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMPUTE, bool>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
     const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
+    const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::convert_and_store<bool, CTYPE_COMPUTE>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 store_compute_to_tensor_fn<CTYPE_COMPUTE>
 get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
   REALHBF16,
   FLOATHBF16,
   INTB,
+  BOOL,
   BOOL_OR_BYTE,
   // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
   SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
       return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
           t);
     case SupportedTensorDtypes::INTB:
       return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_compute_to_tensor_fn_bool_or_byte<
           CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
     const ScalarType compute_type);
 
 /// Return the one output type we are willing to emit specialized code
-/// to handle, given a compute type of CTYPE_COMMON and supported
+/// to handle, given a compute type of CTYPE_COMPUTE and supported
 /// output types of out_dtypes.
 template <typename CTYPE_COMPUTE>
 inline constexpr ScalarType specialized_output_scalar_type(
     SupportedTensorDtypes out_dtypes) {
   switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL:
+      return ScalarType::Bool;
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return ScalarType::Bool;
     case SupportedTensorDtypes::REALHBBF16:
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index c21cceeaae3..34433fbe95c 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -746,6 +746,21 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8),
+// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1],
+// dtype=torch.long)) tensor([16])
+TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) {
+  TensorFactory<ScalarType::Char> tf_in;
+  TensorFactory<ScalarType::Long> tf_out;
+
+  Tensor in = tf_in.make({1}, {100});
+  Tensor out = tf_out.zeros({1});
+  Tensor ret = op_mul_out(in, in, out);
+
+  Tensor expected = tf_out.make({1}, {16});
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
 TEST_F(OpMulScalarOutTest, SanityCheck) {
   TensorFactory<ScalarType::Bool> tf_a;
   TensorFactory<ScalarType::Float> tf_out;