Skip to content

Commit 4f3b35b

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Fix issues with named data map load_into
Summary: Fix a couple bugs in load_into and add tests Differential Revision: D70186266
1 parent fcb40f1 commit 4f3b35b

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
156156
return tensor_layout.error();
157157
}
158158
ET_CHECK_OR_RETURN_ERROR(
159-
size < tensor_layout.get().nbytes(),
159+
size <= tensor_layout.get().nbytes(),
160160
InvalidArgument,
161161
"Buffer size %zu is smaller than tensor size %zu",
162162
size,
@@ -170,12 +170,16 @@ ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
170170
// Load mutable data.
171171
DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
172172
DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
173-
return loader_->load_into(
173+
auto e = loader_->load_into(
174174
header_.segment_base_offset + segment_offset.get() +
175175
metadata.get()->offset(),
176176
tensor_layout.get().nbytes(),
177177
info,
178178
buffer);
179+
if (e != Error::Ok) {
180+
return e;
181+
}
182+
return tensor_layout.get().nbytes();
179183
}
180184

181185
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
@@ -187,6 +191,7 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
187191
if (index < 0 || index >= flat_tensor_->tensors()->size()) {
188192
return Error::InvalidArgument;
189193
}
194+
190195
return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
191196
}
192197

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,26 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
137137
Result<const char*> key2_res = data_map->get_key(2);
138138
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
139139
}
140+
141+
TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
142+
Result<FlatTensorDataMap> data_map =
143+
FlatTensorDataMap::load(data_map_loader_.get());
144+
EXPECT_EQ(data_map.error(), Error::Ok);
145+
146+
// get the metadata
147+
auto meta_data_res = data_map->get_metadata("a");
148+
ASSERT_EQ(meta_data_res.error(), Error::Ok);
149+
150+
// get data blob
151+
void* data = malloc(meta_data_res->nbytes());
152+
auto load_into_res =
153+
data_map->load_data_into("a", data, meta_data_res->nbytes());
154+
ASSERT_EQ(load_into_res.error(), Error::Ok);
155+
156+
// Check tensor data is correct.
157+
float* data_a = static_cast<float*>(data);
158+
for (int i = 0; i < 4; i++) {
159+
EXPECT_EQ(data_a[i], 3.0);
160+
}
161+
free(data);
162+
}

0 commit comments

Comments
 (0)