Skip to content

Commit 9a9db14

Browse files
authored
flat_tensor check size
Differential Revision: D82168471 Pull Request resolved: #14188
1 parent b63b358 commit 9a9db14

File tree

5 files changed

+55
-18
lines changed

5 files changed

+55
-18
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
7474
segments->size());
7575
// Validate the segment.
7676
ET_CHECK_OR_RETURN_ERROR(
77-
segments->Get(segment_index)->offset() < segment_end_offset,
77+
(segments->Get(segment_index)->offset() +
78+
segments->Get(segment_index)->size()) <= segment_end_offset,
7879
InvalidExternalData,
7980
"Invalid segment offset %" PRIu64
8081
" is larger than the segment_base_offset + segment_data_size %" PRIu64
@@ -207,15 +208,21 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
207208
}
208209
Result<FlatTensorHeader> fh =
209210
FlatTensorHeader::Parse(header->data(), header->size());
210-
if (fh.error() == Error::NotFound) {
211-
// No header, throw error.
212-
ET_LOG(Error, "No FlatTensorHeader found.");
213-
return fh.error();
214-
} else if (fh.error() != Error::Ok) {
215-
// corruption, throw error.
216-
ET_LOG(Error, "Flat tensor header may be corrupt.");
217-
return fh.error();
218-
}
211+
212+
ET_CHECK_OR_RETURN_ERROR(
213+
fh.ok(),
214+
InvalidExternalData,
215+
"Failed to parse FlatTensor header with error code %u. File may be corrupt.",
216+
static_cast<uint32_t>(fh.error()));
217+
218+
size_t expected_size = fh->segment_base_offset + fh->segment_data_size;
219+
size_t actual_size = loader->size().get();
220+
ET_CHECK_OR_RETURN_ERROR(
221+
expected_size == actual_size,
222+
InvalidExternalData,
223+
"File size is too small; file may be corrupted or truncated. Expected %zu from flat_tensor header, received %zu from data loader",
224+
expected_size,
225+
actual_size);
219226

220227
// Load flatbuffer data as a segment.
221228
Result<FreeableBuffer> flat_tensor_data = loader->load(

extension/flat_tensor/serialize/serialize.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ runtime::Error save_ptd(
7676
// Write the tensors.
7777
size_t total_segment_size = 0;
7878
uint32_t i = 0;
79+
size_t tensor_count = tensor_map.size();
7980
for (const auto& [name, tensor] : tensor_map) {
8081
auto key = builder.CreateString(name);
8182
// Write the tensor layouts.
@@ -99,7 +100,11 @@ runtime::Error save_ptd(
99100
/*_fbb=*/builder,
100101
/*offset=*/total_segment_size,
101102
/*size=*/tensor.nbytes()));
102-
total_segment_size += aligned_size(tensor.nbytes(), tensor_alignment);
103+
104+
// Do not pad the last tensor.
105+
total_segment_size += (i == tensor_count - 1)
106+
? tensor.nbytes()
107+
: aligned_size(tensor.nbytes(), tensor_alignment);
103108
i++;
104109
}
105110

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/extension/data_loader/buffer_data_loader.h>
910
#include <executorch/extension/data_loader/file_data_loader.h>
1011
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1112
#include <executorch/extension/flat_tensor/serialize/flat_tensor_generated.h>
@@ -17,14 +18,15 @@
1718
#include <gtest/gtest.h>
1819

1920
using namespace ::testing;
21+
using executorch::extension::BufferDataLoader;
22+
using executorch::extension::FileDataLoader;
2023
using executorch::extension::FlatTensorDataMap;
2124
using executorch::extension::FlatTensorHeader;
2225
using executorch::runtime::DataLoader;
2326
using executorch::runtime::Error;
2427
using executorch::runtime::FreeableBuffer;
2528
using executorch::runtime::Result;
2629
using executorch::runtime::TensorLayout;
27-
using torch::executor::util::FileDataLoader;
2830

2931
class FlatTensorDataMapTest : public ::testing::Test {
3032
protected:
@@ -51,7 +53,7 @@ TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
5153
EXPECT_EQ(data_map.error(), Error::Ok);
5254
}
5355

54-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
56+
TEST_F(FlatTensorDataMapTest, GetMetadata) {
5557
Result<FlatTensorDataMap> data_map =
5658
FlatTensorDataMap::load(data_map_loader_.get());
5759
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -93,7 +95,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
9395
EXPECT_EQ(const_c_res.error(), Error::NotFound);
9496
}
9597

96-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
98+
TEST_F(FlatTensorDataMapTest, GetData) {
9799
Result<FlatTensorDataMap> data_map =
98100
FlatTensorDataMap::load(data_map_loader_.get());
99101
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -114,7 +116,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
114116
EXPECT_EQ(data_c_res.error(), Error::NotFound);
115117
}
116118

117-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
119+
TEST_F(FlatTensorDataMapTest, GetKeys) {
118120
Result<FlatTensorDataMap> data_map =
119121
FlatTensorDataMap::load(data_map_loader_.get());
120122
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -138,7 +140,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
138140
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
139141
}
140142

141-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
143+
TEST_F(FlatTensorDataMapTest, LoadInto) {
142144
Result<FlatTensorDataMap> data_map =
143145
FlatTensorDataMap::load(data_map_loader_.get());
144146
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -160,3 +162,23 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
160162
}
161163
free(data);
162164
}
165+
166+
TEST_F(FlatTensorDataMapTest, LoadAndCheckSize) {
167+
Result<FlatTensorDataMap> data_map =
168+
FlatTensorDataMap::load(data_map_loader_.get());
169+
EXPECT_EQ(data_map.error(), Error::Ok);
170+
171+
// Truncate the file.
172+
size_t trunc_size = data_map_loader_->size().get() - 8;
173+
Result<FreeableBuffer> truncated_file = data_map_loader_->load(
174+
0,
175+
trunc_size,
176+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
177+
ASSERT_EQ(truncated_file.error(), Error::Ok);
178+
179+
BufferDataLoader truncated_loader =
180+
BufferDataLoader(truncated_file->data(), trunc_size);
181+
Result<FlatTensorDataMap> truncated_program =
182+
FlatTensorDataMap::load(&truncated_loader);
183+
ASSERT_EQ(truncated_program.error(), Error::InvalidExternalData);
184+
}

extension/flat_tensor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def define_common_targets(is_fbcode=False):
4545
"flat_tensor_data_map_test.cpp",
4646
],
4747
deps = [
48+
"//executorch/extension/data_loader:buffer_data_loader",
4849
"//executorch/extension/data_loader:file_data_loader",
4950
"//executorch/extension/flat_tensor:flat_tensor_data_map",
5051
"//executorch/runtime/core:named_data_map",

extension/flat_tensor/test/test_serialize.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) {
8686
const uint64_t segment_offset = 48 + 280 + 8; // 8 is padding.
8787
EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset);
8888

89-
// Segment total size, 8 bytes of data (2 floats), 24 bytes of padding.
90-
const uint64_t segment_size = 32;
89+
// Segment total size = 20
90+
// linear.bias: 4 bytes + 12 bytes of padding.
91+
// linear.weight: 4 bytes + 0 padding (last segment).
92+
const uint64_t segment_size = 20;
9193
EXPECT_EQ(*(uint64_t*)(header_buffer + 32), segment_size);
9294

9395
// Check Flatbuffer

0 commit comments

Comments
 (0)