diff --git a/kernels/portable/cpu/op_prod.cpp b/kernels/portable/cpu/op_prod.cpp index 9580dee2d12..a1b9f720349 100644 --- a/kernels/portable/cpu/op_prod.cpp +++ b/kernels/portable/cpu/op_prod.cpp @@ -33,8 +33,8 @@ Tensor& prod_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "prod.int_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { const auto data_in = in.const_data_ptr(); auto data_out = out.mutable_data_ptr(); data_out[0] = static_cast(1); @@ -73,8 +73,8 @@ Tensor& prod_int_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "prod.int_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { CTYPE_OUT* out_data = out.mutable_data_ptr(); for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { CTYPE_OUT prod = 1; diff --git a/kernels/test/op_prod_test.cpp b/kernels/test/op_prod_test.cpp index f96eea9564c..a774bc564c6 100644 --- a/kernels/test/op_prod_test.cpp +++ b/kernels/test/op_prod_test.cpp @@ -45,6 +45,24 @@ class OpProdOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } + + template + void test_dtype() { + TensorFactory tf; + TensorFactory< + executorch::runtime::isIntegralType(DTYPE, /*includeBool*/ true) + ? ScalarType::Long + : DTYPE> + tf_out; + + Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + optional dtype{}; + Tensor out = tf_out.zeros({}); + Tensor out_expected = + tf_out.make({}, {DTYPE == ScalarType::Bool ? 1 : 720}); + op_prod_out(self, dtype, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } }; class OpProdIntOutTest : public ::testing::Test { @@ -54,30 +72,32 @@ class OpProdIntOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } -}; -TEST_F(OpProdOutTest, SmokeTest) { - TensorFactory tfFloat; + template + void test_dtype() { + TensorFactory tf; - Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6}); - optional dtype{}; - Tensor out = tfFloat.zeros({}); - Tensor out_expected = tfFloat.make({}, {720}); - op_prod_out(self, dtype, out); - EXPECT_TENSOR_CLOSE(out, out_expected); -} + Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + int64_t dim = 0; + bool keepdim = false; + optional dtype{}; + Tensor out = tf.zeros({3}); + Tensor out_expected = tf.make({3}, {4, 10, 18}); + op_prod_int_out(self, dim, keepdim, dtype, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } +}; -TEST_F(OpProdIntOutTest, SmokeTest) { - TensorFactory tfFloat; +TEST_F(OpProdOutTest, SmokeTest){ +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY) +#undef TEST_ENTRY +} - Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6}); - int64_t dim = 0; - bool keepdim = false; - optional dtype{}; - Tensor out = tfFloat.zeros({3}); - Tensor out_expected = tfFloat.make({3}, {4, 10, 18}); - op_prod_int_out(self, dim, keepdim, dtype, out); - EXPECT_TENSOR_CLOSE(out, out_expected); +TEST_F(OpProdIntOutTest, SmokeTest){ +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY) +#undef TEST_ENTRY } TEST_F(OpProdIntOutTest, SmokeTestKeepdim) { diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h index 9ccda151283..d914cc58c35 100644 --- a/runtime/core/exec_aten/testing_util/tensor_factory.h +++ b/runtime/core/exec_aten/testing_util/tensor_factory.h @@ -279,7 +279,10 @@ class TensorFactory { t = empty_strided(sizes, strides); } if (t.nbytes() > 0) { - memcpy(t.template data(), data.data(), t.nbytes()); + std::transform( + data.begin(), data.end(), t.template data(), [](auto x) { + return static_cast(x); + }); } return t; } @@ -319,7 +322,10 @@ class TensorFactory { t = empty_strided(sizes, strides); } if (t.nbytes() > 0) { - memcpy(t.template data(), data.data(), t.nbytes()); + std::transform( + data.begin(), data.end(), t.template data(), [](auto x) { + return static_cast(x); + }); } return t; } @@ -721,6 +727,13 @@ class TensorFactory { */ using ctype = typename internal::ScalarTypeToCppTypeWrapper::ctype; + /** + * The official C type for the scalar type. Used when accessing elements + * of a constructed Tensor. + */ + using true_ctype = + typename executorch::runtime::ScalarTypeToCppType::type; + TensorFactory() = default; /** @@ -1019,7 +1032,14 @@ class TensorFactory { data_.data(), dim_order_.data(), strides_.data(), - dynamism) {} + dynamism) { + // The only valid values for bool are 0 and 1; coerce! + if constexpr (std::is_same_v) { + for (auto& x : data_) { + x = static_cast(x); + } + } + } std::vector sizes_; std::vector data_;