diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp index 7809f04a1dd..ce9de03444c 100644 --- a/extension/tensor/tensor_ptr_maker.cpp +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -34,8 +34,11 @@ bool extract_scalar(executorch::aten::Scalar scalar, INT_T* out_val) { template < typename FLOAT_T, - typename std::enable_if::value, bool>:: - type = true> + typename std::enable_if< + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, + bool>::type = true> bool extract_scalar(executorch::aten::Scalar scalar, FLOAT_T* out_val) { double val; if (scalar.isFloatingPoint()) { @@ -59,7 +62,7 @@ template < typename std::enable_if::value, bool>::type = true> bool extract_scalar(executorch::aten::Scalar scalar, BOOL_T* out_val) { - if (scalar.isIntegral(false)) { + if (scalar.isIntegral(/*includeBool=*/false)) { *out_val = static_cast(scalar.to()); return true; } @@ -86,7 +89,7 @@ TensorPtr random_strided( empty_strided(std::move(sizes), std::move(strides), type, dynamism); std::default_random_engine gen{std::random_device{}()}; - ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] { + ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] { std::generate_n(tensor->mutable_data_ptr(), tensor->numel(), [&]() { return static_cast(distribution(gen)); }); @@ -121,7 +124,7 @@ TensorPtr full_strided( executorch::aten::TensorShapeDynamism dynamism) { auto tensor = empty_strided(std::move(sizes), std::move(strides), type, dynamism); - ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] { + ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] { CTYPE value; ET_EXTRACT_SCALAR(fill_value, value); std::fill( diff --git a/extension/tensor/test/tensor_ptr_maker_test.cpp b/extension/tensor/test/tensor_ptr_maker_test.cpp index 63e49224b04..e17d18229df 100644 --- a/extension/tensor/test/tensor_ptr_maker_test.cpp +++ b/extension/tensor/test/tensor_ptr_maker_test.cpp @@ -234,6 +234,20 @@ TEST_F(TensorPtrMakerTest, CreateFull) { EXPECT_EQ(tensor4->size(1), 5); EXPECT_EQ(tensor4->scalar_type(), executorch::aten::ScalarType::Double); EXPECT_EQ(tensor4->const_data_ptr()[0], 11); + + auto tensor5 = full({4, 5}, 13, executorch::aten::ScalarType::Half); + EXPECT_EQ(tensor5->dim(), 2); + EXPECT_EQ(tensor5->size(0), 4); + EXPECT_EQ(tensor5->size(1), 5); + EXPECT_EQ(tensor5->scalar_type(), executorch::aten::ScalarType::Half); + EXPECT_EQ(tensor5->const_data_ptr()[0], 13); + + auto tensor6 = full({4, 5}, 15, executorch::aten::ScalarType::BFloat16); + EXPECT_EQ(tensor6->dim(), 2); + EXPECT_EQ(tensor6->size(0), 4); + EXPECT_EQ(tensor6->size(1), 5); + EXPECT_EQ(tensor6->scalar_type(), executorch::aten::ScalarType::BFloat16); + EXPECT_EQ(tensor6->const_data_ptr()[0], 15); } TEST_F(TensorPtrMakerTest, CreateScalar) { @@ -363,6 +377,36 @@ TEST_F(TensorPtrMakerTest, CreateRandTensorWithDoubleType) { } } +TEST_F(TensorPtrMakerTest, CreateRandTensorWithHalfType) { + auto tensor = rand({4, 5}, executorch::aten::ScalarType::Half); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Half); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 0.0); + EXPECT_LT(val, 1.0); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandTensorWithBFloatType) { + auto tensor = rand({4, 5}, executorch::aten::ScalarType::BFloat16); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::BFloat16); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 0.0); + EXPECT_LT(val, 1.0); + } +} + TEST_F(TensorPtrMakerTest, CreateRandnTensor) { auto tensor = randn({100, 100});