Skip to content

Commit ba2c962

Browse files
authored
Add tdbPreLoadMatrixWithIds (#284)
1 parent f57684b commit ba2c962

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

src/include/detail/linalg/tdb_matrix_with_ids.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,63 @@ class tdbBlockedMatrixWithIds
241241
}
242242
}; // tdbBlockedMatrixWithIds
243243

244+
template <
245+
class T,
246+
class IdsType = uint64_t,
247+
class LayoutPolicy = stdx::layout_right,
248+
class I = size_t>
249+
class tdbPreLoadMatrixWithIds
250+
: public tdbBlockedMatrixWithIds<T, IdsType, LayoutPolicy, I> {
251+
using Base = tdbBlockedMatrixWithIds<T, IdsType, LayoutPolicy, I>;
252+
253+
public:
254+
/**
255+
* @brief Construct a new tdbPreLoadMatrixWithIds object, limited to
256+
* `upper_bound` vectors. In this case, the `Matrix` is column-major, so the
257+
* number of vectors is the number of columns.
258+
* Delegates to the overload below, passing std::nullopt for both array sizes.
259+
* @param ctx The TileDB context to use.
260+
* @param uri URI of the TileDB array to read.
261+
* @param upper_bound The maximum number of vectors to read.
262+
*/
263+
tdbPreLoadMatrixWithIds(
264+
const tiledb::Context& ctx,
265+
const std::string& uri,
266+
const std::string& ids_uri,  // forwarded to Base; presumably the URI of the ids array -- confirm
267+
size_t upper_bound = 0,
268+
uint64_t timestamp = 0)  // forwarded to Base unchanged; meaning of 0 is defined by Base -- confirm
269+
: tdbPreLoadMatrixWithIds(
270+
ctx,
271+
uri,
272+
ids_uri,
273+
std::nullopt,  // num_array_rows: not specified by the caller
274+
std::nullopt,  // num_array_cols: not specified by the caller
275+
upper_bound,
276+
timestamp) {
277+
}
278+
279+
tdbPreLoadMatrixWithIds(  // full overload: forwards to Base, then calls Base::load() in the body
280+
const tiledb::Context& ctx,
281+
const std::string& uri,
282+
const std::string& ids_uri,
283+
std::optional<size_t> num_array_rows,
284+
std::optional<size_t> num_array_cols,
285+
size_t upper_bound = 0,
286+
uint64_t timestamp = 0)
287+
: Base(
288+
ctx,
289+
uri,
290+
ids_uri,
291+
0,  // presumably the first row to read -- confirm against the Base constructor
292+
num_array_rows,
293+
0,  // presumably the first column to read -- confirm against the Base constructor
294+
num_array_cols,
295+
upper_bound,
296+
timestamp) {
297+
Base::load();  // invoked at construction time, so callers do not call load() themselves
298+
}
299+
};
300+
244301
/**
245302
* Convenience class for row-major blocked matrices.
246303
*/
@@ -279,4 +336,18 @@ template <
279336
class I = size_t>
280337
using tdbMatrixWithIds = tdbBlockedMatrixWithIds<T, IdsType, LayoutPolicy, I>;
281338

339+
/**
340+
* Convenience alias for tdbPreLoadMatrixWithIds with stdx::layout_right (row-major).
341+
*/
342+
template <class T, class IdsType = uint64_t, class I = size_t>
343+
using tdbRowMajorPreLoadMatrixWithIds =
344+
tdbPreLoadMatrixWithIds<T, IdsType, stdx::layout_right, I>;
345+
346+
/**
347+
* Convenience alias for tdbPreLoadMatrixWithIds with stdx::layout_left (column-major).
348+
*/
349+
template <class T, class IdsType = uint64_t, class I = size_t>
350+
using tdbColMajorPreLoadMatrixWithIds =
351+
tdbPreLoadMatrixWithIds<T, IdsType, stdx::layout_left, I>;
352+
282353
#endif // TDB_MATRIX_WITH_IDS_H

src/include/test/unit_tdb_matrix_with_ids.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,67 @@ TEST_CASE("tdb_matrix_with_ids: empty matrix", "[tdb_matrix_with_ids]") {
308308
CHECK(Y.ids().size() == 1000);
309309
}
310310
}
311+
312+
TEMPLATE_TEST_CASE(  // round-trip: write X, read it back via tdbPreLoadMatrixWithIds, then move-assign
313+
"tdb_matrix_with_ids: preload", "[tdb_matrix_with_ids]", float, uint8_t) {
314+
tiledb::Context ctx;
315+
std::string tmp_matrix_uri =
316+
(std::filesystem::temp_directory_path() / "tmp_tdb_matrix").string();
317+
std::string tmp_ids_uri =
318+
(std::filesystem::temp_directory_path() / "tmp_tdb_ids_matrix").string();
319+
int offset = 13;
320+
size_t Mrows = 200;
321+
size_t Ncols = 500;
322+
323+
tiledb::VFS vfs(ctx);
324+
if (vfs.is_dir(tmp_matrix_uri)) {
325+
vfs.remove_dir(tmp_matrix_uri);  // NOTE(review): tmp_ids_uri is not removed here -- confirm fill_and_write_matrix handles it
326+
}
327+
328+
auto X = ColMajorMatrixWithIds<TestType, TestType>(Mrows, Ncols);  // NOTE(review): ids also use TestType; with uint8_t, 500 distinct ids cannot fit -- confirm intended
329+
fill_and_write_matrix(  // helper defined outside this view; the checks below expect ids()[i] == offset + i
330+
ctx, X, tmp_matrix_uri, tmp_ids_uri, Mrows, Ncols, offset);
331+
CHECK(X.ids()[0] == offset + 0);
332+
CHECK(X.ids()[1] == offset + 1);
333+
CHECK(X.ids()[10] == offset + 10);
334+
335+
auto Y = tdbPreLoadMatrixWithIds<TestType, TestType, stdx::layout_left>(  // default upper_bound and timestamp; no explicit load() call
336+
ctx, tmp_matrix_uri, tmp_ids_uri);
337+
CHECK(num_vectors(Y) == num_vectors(X));
338+
CHECK(dimension(Y) == dimension(X));
339+
CHECK(  // element-wise comparison over the whole data buffer
340+
std::equal(X.data(), X.data() + dimension(X) * num_vectors(X), Y.data()));
341+
for (size_t i = 0; i < 5; ++i) {  // spot-check the top-left 5x5 corner through operator()
342+
for (size_t j = 0; j < 5; ++j) {
343+
CHECK(X(i, j) == Y(i, j));
344+
}
345+
}
346+
347+
CHECK(size(Y.ids()) == Y.num_ids());
348+
CHECK(size(X.ids()) == X.num_ids());
349+
CHECK(X.num_ids() == Y.num_ids());
350+
CHECK(std::equal(X.ids().begin(), X.ids().end(), Y.ids().begin()));
351+
for (size_t i = 0; i < X.num_ids(); ++i) {
352+
CHECK(X.ids()[i] == Y.ids()[i]);
353+
}
354+
355+
auto Z = ColMajorMatrixWithIds<TestType, TestType>(0, 0);  // empty target for the move-assignment below
356+
Z = std::move(Y);  // Y is not used after this point; the remaining checks repeat the comparisons on Z
357+
358+
CHECK(num_vectors(Z) == num_vectors(X));
359+
CHECK(dimension(Z) == dimension(X));
360+
CHECK(
361+
std::equal(X.data(), X.data() + dimension(X) * num_vectors(X), Z.data()));
362+
for (size_t i = 0; i < 5; ++i) {
363+
for (size_t j = 0; j < 5; ++j) {
364+
CHECK(X(i, j) == Z(i, j));
365+
}
366+
}
367+
368+
CHECK(size(Z.ids()) == Z.num_ids());
369+
CHECK(X.num_ids() == Z.num_ids());
370+
CHECK(std::equal(X.ids().begin(), X.ids().end(), Z.ids().begin()));
371+
for (size_t i = 0; i < X.num_ids(); ++i) {
372+
CHECK(X.ids()[i] == Z.ids()[i]);
373+
}
374+
}

0 commit comments

Comments
 (0)