Skip to content

Commit fb4940e

Browse files
committed
Add unit tests for convert_to_bfloat16
1 parent be5d187 commit fb4940e

File tree

5 files changed

+134
-33
lines changed

5 files changed

+134
-33
lines changed

backends/aoti/common_shims.cpp

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,36 @@ AOTITorchError aoti_torch_get_storage_offset(
5050
}
5151

5252
AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
53-
std::vector<int64_t> strides(tensor->dim());
54-
auto tensor_strides = tensor->strides();
55-
for (ssize_t i = 0; i < tensor->dim(); i++) {
56-
strides[i] = static_cast<int64_t>(tensor_strides[i]);
53+
auto it = internal::tensor_to_strides.find(tensor);
54+
bool needs_update = false;
55+
56+
if (it == internal::tensor_to_strides.end()) {
57+
needs_update = true;
58+
} else {
59+
// Check if cached values are still valid
60+
auto tensor_strides = tensor->strides();
61+
if (it->second.size() != static_cast<size_t>(tensor->dim())) {
62+
needs_update = true;
63+
} else {
64+
for (int i = 0; i < tensor->dim(); i++) {
65+
if (it->second[i] != tensor_strides[i]) {
66+
needs_update = true;
67+
break;
68+
}
69+
}
70+
}
71+
}
72+
73+
if (needs_update) {
74+
std::vector<int64_t> strides(tensor->dim());
75+
auto tensor_strides = tensor->strides();
76+
for (int i = 0; i < tensor->dim(); i++) {
77+
strides[i] = tensor_strides[i];
78+
}
79+
it =
80+
internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides))
81+
.first;
5782
}
58-
auto it =
59-
internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides))
60-
.first;
6183

6284
// For 0D tensors, data() returns nullptr on empty vectors, but we need to
6385
// return a valid pointer
@@ -78,13 +100,35 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
78100
}
79101

80102
AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
81-
std::vector<int64_t> sizes(tensor->dim());
82-
auto tensor_sizes = tensor->sizes();
83-
for (ssize_t i = 0; i < tensor->dim(); i++) {
84-
sizes[i] = static_cast<int64_t>(tensor_sizes[i]);
103+
auto it = internal::tensor_to_sizes.find(tensor);
104+
bool needs_update = false;
105+
106+
if (it == internal::tensor_to_sizes.end()) {
107+
needs_update = true;
108+
} else {
109+
// Check if cached values are still valid
110+
auto tensor_sizes = tensor->sizes();
111+
if (it->second.size() != static_cast<size_t>(tensor->dim())) {
112+
needs_update = true;
113+
} else {
114+
for (int i = 0; i < tensor->dim(); i++) {
115+
if (it->second[i] != tensor_sizes[i]) {
116+
needs_update = true;
117+
break;
118+
}
119+
}
120+
}
121+
}
122+
123+
if (needs_update) {
124+
std::vector<int64_t> sizes(tensor->dim());
125+
auto tensor_sizes = tensor->sizes();
126+
for (int i = 0; i < tensor->dim(); i++) {
127+
sizes[i] = tensor_sizes[i];
128+
}
129+
it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes))
130+
.first;
85131
}
86-
auto it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes))
87-
.first;
88132

89133
// For 0D tensors, data() returns nullptr on empty vectors, but we need to
90134
// return a valid pointer

extension/llm/runner/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
1919

2020
set(_test_srcs
2121
test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp
22-
test_text_decoder_runner.cpp test_multimodal_input.cpp test_wav_loader.cpp
22+
test_text_decoder_runner.cpp test_multimodal_input.cpp test_util.cpp
23+
test_wav_loader.cpp
2324
)
2425

2526
# Add LSan stub for Apple platforms

extension/llm/runner/test/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ def define_common_targets():
4545
],
4646
)
4747

48+
runtime.cxx_test(
49+
name = "test_util",
50+
srcs = ["test_util.cpp"],
51+
deps = [
52+
"//executorch/extension/llm/runner:stats",
53+
"//executorch/extension/tensor:tensor",
54+
"//executorch/runtime/core:core",
55+
],
56+
)
57+
4858
runtime.cxx_test(
4959
name = "test_wav_loader",
5060
srcs = ["test_wav_loader.cpp"],

extension/llm/runner/test/test_util.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/llm/runner/util.h>
10+
#include <executorch/extension/tensor/tensor_ptr_maker.h>
11+
12+
#include <gtest/gtest.h>
13+
14+
#include <vector>
15+
16+
namespace {
17+
18+
using ::executorch::aten::ScalarType;
19+
using ::executorch::extension::make_tensor_ptr;
20+
using ::executorch::extension::llm::convert_to_bfloat16;
21+
22+
TEST(ConvertToBFloat16Test, ConvertsFloatTensorData) {
23+
auto source_tensor = make_tensor_ptr<float>(
24+
{2, 2}, std::vector<float>{0.0f, 1.5f, -2.0f, 3.25f});
25+
26+
auto result = convert_to_bfloat16(source_tensor);
27+
ASSERT_TRUE(result.ok());
28+
auto bf16_tensor = *result;
29+
30+
EXPECT_EQ(bf16_tensor->scalar_type(), ScalarType::BFloat16);
31+
EXPECT_EQ(bf16_tensor->numel(), source_tensor->numel());
32+
33+
auto src_sizes = source_tensor->sizes();
34+
auto dst_sizes = bf16_tensor->sizes();
35+
ASSERT_EQ(dst_sizes.size(), src_sizes.size());
36+
for (size_t dim = 0; dim < dst_sizes.size(); ++dim) {
37+
EXPECT_EQ(dst_sizes[dim], src_sizes[dim]);
38+
}
39+
40+
const auto* converted_data = bf16_tensor->const_data_ptr<::c10::BFloat16>();
41+
const auto* original_data = source_tensor->const_data_ptr<float>();
42+
ASSERT_NE(converted_data, nullptr);
43+
ASSERT_NE(original_data, nullptr);
44+
45+
for (size_t i = 0; i < static_cast<size_t>(source_tensor->numel()); ++i) {
46+
EXPECT_NEAR(static_cast<float>(converted_data[i]), original_data[i], 1e-2f);
47+
}
48+
}
49+
50+
TEST(ConvertToBFloat16Test, RejectsNonFloatTensor) {
51+
auto non_float_tensor =
52+
make_tensor_ptr<int64_t>({3}, std::vector<int64_t>{1, 2, 3});
53+
54+
auto result = convert_to_bfloat16(non_float_tensor);
55+
EXPECT_FALSE(result.ok());
56+
EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument);
57+
}
58+
59+
} // namespace

extension/llm/runner/util.h

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,30 +152,17 @@ convert_to_bfloat16(const ::executorch::extension::TensorPtr& src_tensor) {
152152
InvalidArgument,
153153
"BFloat16 conversion only supported from Float source data");
154154

155-
size_t num_elements = src_tensor->numel();
156-
auto sizes = src_tensor->sizes();
157-
158-
// Allocate memory for bfloat16 data
159-
auto* bf16_data = new uint16_t[num_elements];
155+
const auto num_elements = static_cast<size_t>(src_tensor->numel());
160156
const float* float_data = src_tensor->const_data_ptr<float>();
161157

162-
// Convert float to bfloat16
158+
auto bf16_tensor = ::executorch::extension::empty_like(
159+
src_tensor, ::executorch::aten::ScalarType::BFloat16);
160+
auto* bf16_data = bf16_tensor->mutable_data_ptr<::c10::BFloat16>();
163161
for (size_t i = 0; i < num_elements; ++i) {
164-
// bfloat16 is the upper 16 bits of float32
165-
uint32_t float_bits;
166-
std::memcpy(&float_bits, &float_data[i], sizeof(float));
167-
168-
// Rounding: add 0x7FFF to round to nearest even
169-
uint32_t rounding_bias = 0x7FFF + ((float_bits >> 16) & 1);
170-
bf16_data[i] = static_cast<uint16_t>((float_bits + rounding_bias) >> 16);
162+
bf16_data[i] = ::c10::BFloat16(float_data[i]);
171163
}
172164

173-
// Create tensor with deleter to free allocated memory
174-
return ::executorch::extension::from_blob(
175-
bf16_data,
176-
{sizes.begin(), sizes.end()},
177-
::executorch::aten::ScalarType::BFloat16,
178-
[](void* ptr) { delete[] static_cast<uint16_t*>(ptr); });
165+
return bf16_tensor;
179166
}
180167

181168
} // namespace llm

0 commit comments

Comments
 (0)