Skip to content

Commit fea2310

Browse files
committed
Implement load_into for SharedPtrDataLoader and add test (#11562)
1 parent 56392aa commit fea2310

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

extension/data_loader/shared_ptr_data_loader.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/runtime/core/result.h>
1414
#include <executorch/runtime/platform/log.h>
1515
#include <memory>
16+
#include <cstring>
1617

1718
namespace executorch {
1819
namespace extension {
@@ -43,11 +44,35 @@ class SharedPtrDataLoader final : public executorch::runtime::DataLoader {
4344
return executorch::runtime::FreeableBuffer(
4445
static_cast<uint8_t*>(data_.get()) + offset, size, /*free_fn=*/nullptr);
4546
}
47+
48+
ET_NODISCARD executorch::runtime::Error load_into(
49+
size_t offset,
50+
size_t size,
51+
const DataLoader::SegmentInfo& segment_info,
52+
void* buffer) const override;
4653

4754
ET_NODISCARD executorch::runtime::Result<size_t> size() const override {
4855
return size_;
4956
}
5057

58+
ET_NODISCARD executorch::runtime::Error SharedPtrDataLoader::load_into(
59+
size_t offset,
60+
size_t size,
61+
const DataLoader::SegmentInfo& segment_info,
62+
void* buffer) const {
63+
ET_CHECK_OR_RETURN_ERROR(
64+
offset + size <= size_,
65+
executorch::runtime::Error::OutOfBounds,
66+
"offset %zu + size %zu exceeds buffer size %zu",
67+
offset,
68+
size,
69+
size_);
70+
71+
std::memcpy(buffer, static_cast<uint8_t*>(data_.get()) + offset, size);
72+
return executorch::runtime::Error::Ok;
73+
}
74+
75+
5176
private:
5277
const std::shared_ptr<void> data_;
5378
const size_t size_;

extension/data_loader/test/shared_ptr_data_loader_test.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,36 @@ TEST_F(SharedPtrDataLoaderTest, OutOfBoundsLoadFails) {
140140
EXPECT_NE(fb.error(), Error::Ok);
141141
}
142142
}
143+
144+
// Unit test to check that SharedPtrDataLoader::load_into copies the correct data.
145+
TEST(SharedPtrDataLoaderTest, LoadIntoCopiesCorrectData) {
146+
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};
147+
148+
// Wrap the source data in a shared_ptr without taking ownership.
149+
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
150+
SharedPtrDataLoader loader(data_ptr, source_data.size());
151+
152+
uint8_t buffer[3] = {0};
153+
154+
// Load 3 bytes starting from offset 1 (expecting values 20, 30, 40).
155+
auto err = loader.load_into(1, 3, DataLoader::SegmentInfo{}, buffer);
156+
157+
EXPECT_EQ(err, Error::Ok);
158+
EXPECT_EQ(buffer[0], 20);
159+
EXPECT_EQ(buffer[1], 30);
160+
EXPECT_EQ(buffer[2], 40);
161+
}
162+
163+
// Unit test to verify that SharedPtrDataLoader::load_into handles out-of-bounds requests.
164+
TEST(SharedPtrDataLoaderTest, LoadIntoRejectsOutOfBoundsAccess) {
165+
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};
166+
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
167+
SharedPtrDataLoader loader(data_ptr, source_data.size());
168+
169+
uint8_t buffer[3] = {0};
170+
171+
// This should fail because offset + size = 4 + 3 = 7 > 5 (size of data).
172+
auto err = loader.load_into(4, 3, DataLoader::SegmentInfo{}, buffer);
173+
174+
EXPECT_EQ(err, Error::OutOfBounds);
175+
}

0 commit comments

Comments
 (0)