diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp
index d240b9f83bc..525199a6f78 100644
--- a/kernels/portable/cpu/util/dtype_util.cpp
+++ b/kernels/portable/cpu/util/dtype_util.cpp
@@ -27,6 +27,8 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::INTB:
       return executorch::runtime::tensor_is_integral_type(t, true);
+    case SupportedTensorDtypes::BOOL:
+      return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (executorch::runtime::tensor_is_type(
           t, ScalarType::Bool, ScalarType::Byte));
diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h
index 1e7901c80b2..15732219c8f 100644
--- a/kernels/portable/cpu/util/dtype_util.h
+++ b/kernels/portable/cpu/util/dtype_util.h
@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool(const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMPUTE, bool>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte(
     const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
   return result;
 }
 
+template <typename CTYPE_COMPUTE, const char* op_name>
+store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool(
+    const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::convert_and_store<bool, CTYPE_COMPUTE>;
+}
+
 template <typename CTYPE_COMPUTE, const char* op_name>
 store_compute_to_tensor_fn<CTYPE_COMPUTE>
 get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
   REALHBF16,
   FLOATHBF16,
   INTB,
+  BOOL,
   BOOL_OR_BYTE,
   // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
   SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
       return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
           t);
     case SupportedTensorDtypes::INTB:
       return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_compute_to_tensor_fn_bool_or_byte<
           CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
     const ScalarType compute_type);
 
 /// Return the one output type we are willing to emit specialized code
-/// to handle, given a compute type of CTYPE_COMMON and supported
+/// to handle, given a compute type of CTYPE_COMPUTE and supported
 /// output types of out_dtypes.
 template <typename CTYPE_COMPUTE>
 inline constexpr ScalarType specialized_output_scalar_type(
     SupportedTensorDtypes out_dtypes) {
   switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL:
+      return ScalarType::Bool;
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return ScalarType::Bool;
     case SupportedTensorDtypes::REALHBBF16:
diff --git a/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h b/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h
index 6e49dd9e57b..d1e812ec2c2 100644
--- a/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h
+++ b/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h
@@ -72,20 +72,16 @@ class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest {
     auto expected = tf_out.make({1, 6}, expected_vector);
     if (IN_DTYPE == ScalarType::BFloat16 ||
         OUT_DTYPE == ScalarType::BFloat16) {
-      double rtol = executorch::runtime::testing::internal::kDefaultRtol;
-      // It appears we need a higher tolerance for at least some ATen
-      // tests, like aten_op_acosh_test.
-      if (get_supported_features()->is_aten) {
-        rtol = 3e-3;
-      }
+      // Raise tolerance because both we and ATen run these
+      // computations at internal float32 precision rather than
+      // float64.
+      double rtol = 3e-3;
       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol);
     } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) {
-      double rtol = executorch::runtime::testing::internal::kDefaultRtol;
-      // It appears we need a higher tolerance for at least some ATen
-      // tests, like aten_op_acosh_test.
-      if (get_supported_features()->is_aten) {
-        rtol = 1e-3;
-      }
+      // Raise tolerance because both we and ATen run these
+      // computations at internal float32 precision rather than
+      // float64.
+      double rtol = 1e-3;
       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol);
     } else {
       EXPECT_TENSOR_CLOSE(out, expected);
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index e109193e227..04961d6e0b0 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -742,6 +742,21 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8),
+// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1],
+// dtype=torch.long))
+// tensor([16])
+TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) {
+  TensorFactory<ScalarType::Char> tf_in;
+  TensorFactory<ScalarType::Long> tf_out;
+
+  Tensor in = tf_in.make({1}, {100});
+  Tensor out = tf_out.zeros({1});
+  Tensor ret = op_mul_out(in, in, out);
+
+  Tensor expected = tf_out.make({1}, {16});
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
 TEST_F(OpMulScalarOutTest, SanityCheck) {
   TensorFactory tf_a;
   TensorFactory tf_out;
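
Note on the new op_mul test (illustration only, not part of the patch): ATen performs the multiply at the common dtype of the inputs (int8 here), so 100 * 100 = 10000 wraps modulo 256 to 16 before the result is widened into the int64 out tensor. A minimal standalone C++ sketch of that arithmetic; the variable names are illustrative, and the narrowing cast relies on the modular signed-integer conversion that C++20 guarantees (and mainstream compilers implement anyway):

#include <cstdint>
#include <iostream>

int main() {
  const int8_t a = 100;
  const int8_t b = 100;
  // Multiply at the common input dtype. a * b promotes to int (10000);
  // converting back to int8_t reduces modulo 256: 10000 % 256 == 16.
  const auto product_at_common_dtype = static_cast<int8_t>(a * b);
  // Only after the wrap is the value converted to the wider out dtype.
  const int64_t widened = product_at_common_dtype;
  std::cout << widened << "\n";  // prints 16, matching tensor([16]) above
}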