Skip to content

Commit bacd7e8

Browse files
authored
Add support for reading an empty vector with read_vector() (#269)
1 parent e0de0a1 commit bacd7e8

File tree

3 files changed

+104
-55
lines changed

3 files changed

+104
-55
lines changed

src/include/detail/linalg/tdb_io.h

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,77 @@
4343
#include "utils/print_types.h"
4444
#include "utils/timer.h"
4545

46+
namespace {
47+
48+
template <class T>
49+
std::vector<T> read_vector_helper(
50+
const tiledb::Context& ctx,
51+
const std::string& uri,
52+
size_t start_pos,
53+
size_t end_pos,
54+
size_t timestamp,
55+
bool read_full_vector) {
56+
scoped_timer _{tdb_func__ + " " + std::string{uri}};
57+
58+
tiledb::TemporalPolicy temporal_policy =
59+
(timestamp == 0) ? tiledb::TemporalPolicy() :
60+
tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
61+
62+
auto array_ = tiledb_helpers::open_array(
63+
tdb_func__, ctx, uri, TILEDB_READ, temporal_policy);
64+
auto schema_ = array_->schema();
65+
66+
using domain_type = int32_t;
67+
const size_t idx = 0;
68+
69+
auto domain_{schema_.domain()};
70+
71+
auto dim_num_{domain_.ndim()};
72+
auto array_rows_{domain_.dimension(0)};
73+
74+
if (read_full_vector) {
75+
if (start_pos == 0) {
76+
start_pos = array_rows_.template domain<domain_type>().first;
77+
}
78+
if (end_pos == 0) {
79+
end_pos = array_rows_.template domain<domain_type>().second + 1;
80+
}
81+
}
82+
83+
auto vec_rows_{end_pos - start_pos};
84+
85+
if (vec_rows_ == 0) {
86+
return {};
87+
}
88+
89+
auto attr_num{schema_.attribute_num()};
90+
auto attr = schema_.attribute(idx);
91+
92+
std::string attr_name = attr.name();
93+
tiledb_datatype_t attr_type = attr.type();
94+
95+
// Create a subarray that reads the array up to the specified subset.
96+
std::vector<int32_t> subarray_vals = {
97+
(int32_t)start_pos, (int32_t)end_pos - 1};
98+
tiledb::Subarray subarray(ctx, *array_);
99+
subarray.set_subarray(subarray_vals);
100+
101+
// @todo: use something non-initializing
102+
std::vector<T> data_(vec_rows_);
103+
104+
tiledb::Query query(ctx, *array_);
105+
query.set_subarray(subarray).set_data_buffer(
106+
attr_name, data_.data(), vec_rows_);
107+
tiledb_helpers::submit_query(tdb_func__, uri, query);
108+
_memory_data.insert_entry(tdb_func__, vec_rows_ * sizeof(T));
109+
110+
array_->close();
111+
assert(tiledb::Query::Status::COMPLETE == query.query_status());
112+
113+
return data_;
114+
}
115+
} // namespace
116+
46117
/******************************************************************************
47118
* Matrix creation and writing. Because we support out-of-core operations with
48119
* matrices, reading a matrix is more subtle and is contained in the tdb_matrix
@@ -290,67 +361,16 @@ std::vector<T> read_vector(
290361
size_t start_pos,
291362
size_t end_pos,
292363
size_t timestamp = 0) {
293-
scoped_timer _{tdb_func__ + " " + std::string{uri}};
294-
295-
tiledb::TemporalPolicy temporal_policy =
296-
(timestamp == 0) ? tiledb::TemporalPolicy() :
297-
tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp);
298-
299-
auto array_ = tiledb_helpers::open_array(
300-
tdb_func__, ctx, uri, TILEDB_READ, temporal_policy);
301-
auto schema_ = array_->schema();
302-
303-
using domain_type = int32_t;
304-
const size_t idx = 0;
305-
306-
auto domain_{schema_.domain()};
307-
308-
auto dim_num_{domain_.ndim()};
309-
auto array_rows_{domain_.dimension(0)};
310-
311-
if (start_pos == 0) {
312-
start_pos = array_rows_.template domain<domain_type>().first;
313-
}
314-
if (end_pos == 0) {
315-
end_pos = array_rows_.template domain<domain_type>().second + 1;
316-
}
317-
318-
auto vec_rows_{end_pos - start_pos};
319-
320-
auto attr_num{schema_.attribute_num()};
321-
auto attr = schema_.attribute(idx);
322-
323-
std::string attr_name = attr.name();
324-
tiledb_datatype_t attr_type = attr.type();
325-
326-
// Create a subarray that reads the array up to the specified subset.
327-
std::vector<int32_t> subarray_vals = {
328-
(int32_t)start_pos, (int32_t)end_pos - 1};
329-
tiledb::Subarray subarray(ctx, *array_);
330-
subarray.set_subarray(subarray_vals);
331-
332-
// @todo: use something non-initializing
333-
std::vector<T> data_(vec_rows_);
334-
335-
tiledb::Query query(ctx, *array_);
336-
query.set_subarray(subarray).set_data_buffer(
337-
attr_name, data_.data(), vec_rows_);
338-
tiledb_helpers::submit_query(tdb_func__, uri, query);
339-
_memory_data.insert_entry(tdb_func__, vec_rows_ * sizeof(T));
340-
341-
array_->close();
342-
assert(tiledb::Query::Status::COMPLETE == query.query_status());
343-
344-
return data_;
364+
return read_vector_helper<T>(ctx, uri, start_pos, end_pos, timestamp, false);
345365
}
346366

347367
/**
348-
* @brief Overload with 0 for start_pos and end_pos (read entire vector)
368+
* @brief Overload to read the entire vector.
349369
*/
350370
template <class T>
351371
std::vector<T> read_vector(
352372
const tiledb::Context& ctx, const std::string& uri, size_t timestamp = 0) {
353-
return read_vector<T>(ctx, uri, 0, 0, timestamp);
373+
return read_vector_helper<T>(ctx, uri, 0, 0, timestamp, true);
354374
}
355375

356376
template <class T>

src/include/detail/linalg/tdb_vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class tdbVector : public Vector<T> {
4848

4949
public:
5050
tdbVector(const tiledb::Context& ctx, const std::string& uri)
51-
: Base(read_vector<T>(ctx, uri, 0, 0, 0)) {
51+
: Base(read_vector<T>(ctx, uri)) {
5252
}
5353
};
5454

src/include/test/unit_tdb_io.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,35 @@ TEMPLATE_TEST_CASE("tdb_io: write matrix", "[tdb_io]", float, uint8_t) {
151151
}
152152
}
153153

154+
// Checks both read paths of read_vector() against an array produced by
// create_empty_for_vector():
//   * an explicit [0, 0) range must yield an empty vector;
//   * the whole-vector overload must yield `domain` elements.
// NOTE(review): the test name says "write" although it exercises reading;
// rename once it is confirmed the new name does not clash with another test.
TEST_CASE("tdb_io: write empty vector", "[tdb_io]") {
  tiledb::Context ctx;
  const std::string tmp_vector_uri =
      (std::filesystem::temp_directory_path() / "tmp_vector").string();

  constexpr size_t dimension = 128;
  constexpr int32_t domain{10000};
  constexpr int32_t tile_size_bytes{1024};
  const tiledb_filter_type_t compression{string_to_filter("zstd")};
  constexpr int32_t tile_size{
      static_cast<int32_t>(tile_size_bytes / sizeof(float) / dimension)};
  const size_t timestamp = 0;

  // Start from a clean slate in case an earlier run left the array behind.
  tiledb::VFS vfs(ctx);
  if (vfs.is_dir(tmp_vector_uri)) {
    vfs.remove_dir(tmp_vector_uri);
  }

  create_empty_for_vector<float>(
      ctx, tmp_vector_uri, domain, tile_size, compression);

  // Explicit zero-length range: nothing should be read.
  auto empty_vector = read_vector<float>(ctx, tmp_vector_uri, 0, 0, timestamp);
  CHECK(empty_vector.empty());

  // Whole-vector overload: the result should span the full domain.
  auto filled_vector = read_vector<float>(ctx, tmp_vector_uri);
  CHECK(filled_vector.size() == static_cast<size_t>(domain));
}
182+
154183
TEST_CASE("tdb_io: create group", "[tdb_io]") {
155184
size_t N = 10'000;
156185

0 commit comments

Comments
 (0)