Skip to content

Commit 8587df7

Browse files
lucylqfacebook-github-bot
authored andcommitted
flat_tensor check size (pytorch#14188)
Summary: Add size check similar to D81938296, for flat tensor Differential Revision: D82168471
1 parent f294074 commit 8587df7

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
7373
segments->size());
7474
// Validate the segment.
7575
ET_CHECK_OR_RETURN_ERROR(
76-
segments->Get(segment_index)->offset() < segment_end_offset,
76+
(segments->Get(segment_index)->offset() +
77+
segments->Get(segment_index)->size()) < segment_end_offset,
7778
InvalidExternalData,
7879
"Invalid segment offset %" PRIu64
7980
" is larger than the segment_base_offset + segment_data_size %" PRIu64
@@ -206,15 +207,21 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
206207
}
207208
Result<FlatTensorHeader> fh =
208209
FlatTensorHeader::Parse(header->data(), header->size());
209-
if (fh.error() == Error::NotFound) {
210-
// No header, throw error.
211-
ET_LOG(Error, "No FlatTensorHeader found.");
212-
return fh.error();
213-
} else if (fh.error() != Error::Ok) {
214-
// corruption, throw error.
215-
ET_LOG(Error, "Flat tensor header may be corrupt.");
216-
return fh.error();
217-
}
210+
211+
ET_CHECK_OR_RETURN_ERROR(
212+
fh.ok(),
213+
InvalidExternalData,
214+
"Failed to parse FlatTensor header with error code %u. File may be corrupt.",
215+
static_cast<uint32_t>(fh.error()));
216+
217+
size_t expected_size = fh->segment_base_offset + fh->segment_data_size;
218+
size_t actual_size = loader->size().get();
219+
ET_CHECK_OR_RETURN_ERROR(
220+
expected_size == actual_size,
221+
InvalidExternalData,
222+
"File size is too small; file may be corrupted or truncated. Expected %zu from flat_tensor header, received %zu from data loader",
223+
expected_size,
224+
actual_size);
218225

219226
// Load flatbuffer data as a segment.
220227
Result<FreeableBuffer> flat_tensor_data = loader->load(

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/extension/data_loader/buffer_data_loader.h>
910
#include <executorch/extension/data_loader/file_data_loader.h>
1011
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1112
#include <executorch/extension/flat_tensor/serialize/flat_tensor_generated.h>
@@ -17,14 +18,15 @@
1718
#include <gtest/gtest.h>
1819

1920
using namespace ::testing;
21+
using executorch::extension::BufferDataLoader;
22+
using executorch::extension::FileDataLoader;
2023
using executorch::extension::FlatTensorDataMap;
2124
using executorch::extension::FlatTensorHeader;
2225
using executorch::runtime::DataLoader;
2326
using executorch::runtime::Error;
2427
using executorch::runtime::FreeableBuffer;
2528
using executorch::runtime::Result;
2629
using executorch::runtime::TensorLayout;
27-
using torch::executor::util::FileDataLoader;
2830

2931
class FlatTensorDataMapTest : public ::testing::Test {
3032
protected:
@@ -51,7 +53,7 @@ TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
5153
EXPECT_EQ(data_map.error(), Error::Ok);
5254
}
5355

54-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
56+
TEST_F(FlatTensorDataMapTest, GetMetadata) {
5557
Result<FlatTensorDataMap> data_map =
5658
FlatTensorDataMap::load(data_map_loader_.get());
5759
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -93,7 +95,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
9395
EXPECT_EQ(const_c_res.error(), Error::NotFound);
9496
}
9597

96-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
98+
TEST_F(FlatTensorDataMapTest, GetData) {
9799
Result<FlatTensorDataMap> data_map =
98100
FlatTensorDataMap::load(data_map_loader_.get());
99101
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -114,7 +116,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
114116
EXPECT_EQ(data_c_res.error(), Error::NotFound);
115117
}
116118

117-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
119+
TEST_F(FlatTensorDataMapTest, GetKeys) {
118120
Result<FlatTensorDataMap> data_map =
119121
FlatTensorDataMap::load(data_map_loader_.get());
120122
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -138,7 +140,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
138140
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
139141
}
140142

141-
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
143+
TEST_F(FlatTensorDataMapTest, LoadInto) {
142144
Result<FlatTensorDataMap> data_map =
143145
FlatTensorDataMap::load(data_map_loader_.get());
144146
EXPECT_EQ(data_map.error(), Error::Ok);
@@ -160,3 +162,23 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
160162
}
161163
free(data);
162164
}
165+
166+
TEST_F(FlatTensorDataMapTest, LoadAndCheckSize) {
167+
Result<FlatTensorDataMap> data_map =
168+
FlatTensorDataMap::load(data_map_loader_.get());
169+
EXPECT_EQ(data_map.error(), Error::Ok);
170+
171+
// Truncate the file.
172+
size_t trunc_size = data_map_loader_->size().get() - 8;
173+
Result<FreeableBuffer> truncated_file = data_map_loader_->load(
174+
0,
175+
trunc_size,
176+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
177+
ASSERT_EQ(truncated_file.error(), Error::Ok);
178+
179+
BufferDataLoader truncated_loader =
180+
BufferDataLoader(truncated_file->data(), trunc_size);
181+
Result<FlatTensorDataMap> truncated_program =
182+
FlatTensorDataMap::load(&truncated_loader);
183+
ASSERT_EQ(truncated_program.error(), Error::InvalidExternalData);
184+
}

extension/flat_tensor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def define_common_targets(is_fbcode=False):
4646
],
4747
deps = [
4848
"//executorch/extension/data_loader:file_data_loader",
49+
"//executorch/extension/data_loader:buffer_data_loader",
4950
"//executorch/extension/flat_tensor:flat_tensor_data_map",
5051
"//executorch/runtime/core:named_data_map",
5152
"//executorch/runtime/core/exec_aten:lib",

0 commit comments

Comments
 (0)