Skip to content

Commit 7515329

Browse files
authored
Fix bug in tdb matrix where we would read even when the array is empty and get NaN values (#368)
1 parent f4d08e3 commit 7515329

File tree

6 files changed

+198
-27
lines changed

6 files changed

+198
-27
lines changed

src/include/detail/linalg/tdb_matrix.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,40 +239,35 @@ class tdbBlockedMatrix : public MatrixBase {
239239
throw std::runtime_error("Cell order and matrix order must match");
240240
}
241241

242-
// @todo Maybe throw an exception here instead of just an assert?
243-
// Have to properly handle an exception since this is a constructor.
244-
assert(cell_order == tile_order);
245-
246-
const size_t attr_idx = 0;
242+
if (cell_order != tile_order) {
243+
throw std::runtime_error("Cell order and tile order must match");
244+
}
247245

248246
auto domain_{schema_.domain()};
249247

250248
auto row_domain{domain_.dimension(0)};
251249
auto col_domain{domain_.dimension(1)};
252250

253-
// If the user specifies a value then we use it, otherwise we use the
254-
// non-empty domain. If non_empty_domain() is an empty vector it means that
255-
// the array is empty.
251+
// If non_empty_domain() is an empty vector it means that
252+
// the array is empty. Else, if the user specifies a value then we use it,
253+
// otherwise we use the non-empty domain.
256254
auto non_empty_domain = array_->non_empty_domain<int>();
257-
if (!last_row.has_value()) {
258-
if (non_empty_domain.empty()) {
259-
last_row_ = 0;
255+
if (non_empty_domain.empty()) {
256+
last_row_ = 0;
257+
last_col_ = 0;
258+
} else {
259+
if (last_row.has_value()) {
260+
last_row_ = *last_row;
260261
} else {
261262
last_row_ = non_empty_domain[0].second.second -
262263
non_empty_domain[0].second.first + 1;
263264
}
264-
} else {
265-
last_row_ = *last_row;
266-
}
267-
if (!last_col.has_value()) {
268-
if (non_empty_domain.empty()) {
269-
last_col_ = 0;
265+
if (last_col.has_value()) {
266+
last_col_ = *last_col;
270267
} else {
271268
last_col_ = non_empty_domain[1].second.second -
272269
non_empty_domain[1].second.first + 1;
273270
}
274-
} else {
275-
last_col_ = *last_col;
276271
}
277272

278273
size_t dimension = last_row_ - first_row_;

src/include/test/test_utils.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,14 @@ void fill_and_write_matrix(
6262
const std::string& uri,
6363
size_t rows,
6464
size_t cols,
65-
size_t offset) {
65+
size_t offset,
66+
TemporalPolicy temporal_policy = {}) {
6667
tiledb::VFS vfs(ctx);
6768
if (vfs.is_dir(uri)) {
6869
vfs.remove_dir(uri);
6970
}
7071
std::iota(X.data(), X.data() + dimension(X) * num_vectors(X), offset);
71-
write_matrix(ctx, X, uri);
72+
write_matrix(ctx, X, uri, 0, true, temporal_policy);
7273
}
7374

7475
/*
@@ -91,7 +92,8 @@ void fill_and_write_matrix(
9192
const std::string& ids_uri,
9293
size_t rows,
9394
size_t cols,
94-
size_t offset) {
95+
size_t offset,
96+
TemporalPolicy temporal_policy = {}) {
9597
tiledb::VFS vfs(ctx);
9698
if (vfs.is_dir(uri)) {
9799
vfs.remove_dir(uri);
@@ -103,10 +105,10 @@ void fill_and_write_matrix(
103105
std::iota(X.ids().begin(), X.ids().end(), offset);
104106

105107
// Write the vectors to their URI.
106-
write_matrix(ctx, X, uri);
108+
write_matrix(ctx, X, uri, 0, true, temporal_policy);
107109

108110
// Write the IDs to their URI.
109-
write_vector(ctx, X.ids(), ids_uri);
111+
write_vector(ctx, X.ids(), ids_uri, 0, true, temporal_policy);
110112
}
111113

112114
void validate_metadata(

src/include/test/unit_api_feature_vector_array.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ TEST_CASE("api: temporal_policy", "[api]") {
664664
}
665665
}
666666

667-
// Read the data at timestamp 99 explicitly.
667+
// Read the data at timestamp 99.
668668
{
669669
auto feature_vector_array = FeatureVectorArray(
670670
ctx, feature_vectors_uri, ids_uri, 0, TemporalPolicy(TimeTravel, 99));
@@ -684,4 +684,24 @@ TEST_CASE("api: temporal_policy", "[api]") {
684684
}
685685
}
686686
}
687+
688+
// Read the data at timestamp 50.
689+
{
690+
auto feature_vector_array = FeatureVectorArray(
691+
ctx, feature_vectors_uri, ids_uri, 0, TemporalPolicy(TimeTravel, 50));
692+
CHECK(extents(feature_vector_array)[0] == 0);
693+
CHECK(extents(feature_vector_array)[1] == 0);
694+
CHECK(feature_vector_array.num_vectors() == 0);
695+
CHECK(feature_vector_array.num_ids() == 0);
696+
CHECK(feature_vector_array.dimension() == 0);
697+
auto data = MatrixView<FeatureType, stdx::layout_left>{
698+
(FeatureType*)feature_vector_array.data(),
699+
extents(feature_vector_array)[0],
700+
extents(feature_vector_array)[1]};
701+
auto ids = std::span<IdsType>(
702+
(IdsType*)feature_vector_array.ids_data(),
703+
feature_vector_array.num_vectors());
704+
CHECK(ids.size() == 0);
705+
CHECK(data.size() == 0);
706+
}
687707
}

src/include/test/unit_tdb_io.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ TEST_CASE("tdb_io: write empty matrix", "[tdb_io]") {
174174
empty_matrix.load();
175175
CHECK(num_vectors(empty_matrix) == 0);
176176
CHECK(empty_matrix.num_cols() == 0);
177-
CHECK(empty_matrix.num_rows() == dimension);
177+
CHECK(empty_matrix.num_rows() == 0);
178178

179179
auto empty_preload_matrix =
180180
tdbColMajorPreLoadMatrix<float>(ctx, tmp_matrix_uri, dimension, 0, 0, {});
181181
CHECK(num_vectors(empty_preload_matrix) == 0);
182182
CHECK(empty_preload_matrix.num_cols() == 0);
183-
CHECK(empty_preload_matrix.num_rows() == dimension);
183+
CHECK(empty_preload_matrix.num_rows() == 0);
184184
}
185185

186186
TEST_CASE("tdb_io: write empty vector", "[tdb_io]") {

src/include/test/unit_tdb_matrix.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,70 @@ TEST_CASE("tdb_matrix: empty matrix", "[tdb_matrix]") {
336336
CHECK(dimension(X) == 0);
337337
}
338338
}
339+
340+
TEST_CASE("tdb_matrix: time travel", "[tdb_matrix]") {
341+
tiledb::Context ctx;
342+
std::string tmp_matrix_uri =
343+
(std::filesystem::temp_directory_path() / "tmp_tdb_matrix").string();
344+
int offset = 13;
345+
346+
size_t Mrows = 20;
347+
size_t Ncols = 50;
348+
349+
tiledb::VFS vfs(ctx);
350+
if (vfs.is_dir(tmp_matrix_uri)) {
351+
vfs.remove_dir(tmp_matrix_uri);
352+
}
353+
354+
auto X = ColMajorMatrix<int>(Mrows, Ncols);
355+
std::iota(X.data(), X.data() + dimension(X) * num_vectors(X), offset);
356+
write_matrix(ctx, X, tmp_matrix_uri, 0, true, TemporalPolicy{TimeTravel, 50});
357+
358+
{
359+
// We can load the matrix at the creation timestamp.
360+
auto Y = tdbPreLoadMatrix<int, stdx::layout_left>(
361+
ctx, tmp_matrix_uri, 0, TemporalPolicy{TimeTravel, 50});
362+
CHECK(num_vectors(Y) == num_vectors(X));
363+
CHECK(dimension(Y) == dimension(X));
364+
CHECK(std::equal(
365+
X.data(), X.data() + dimension(X) * num_vectors(X), Y.data()));
366+
for (size_t i = 0; i < Mrows; ++i) {
367+
for (size_t j = 0; j < Ncols; ++j) {
368+
CHECK(X(i, j) == Y(i, j));
369+
}
370+
}
371+
}
372+
373+
{
374+
// We can load the matrix at a later timestamp.
375+
auto Y = tdbPreLoadMatrix<int, stdx::layout_left>(
376+
ctx, tmp_matrix_uri, 0, TemporalPolicy{TimeTravel, 100});
377+
CHECK(num_vectors(Y) == num_vectors(X));
378+
CHECK(dimension(Y) == dimension(X));
379+
CHECK(std::equal(
380+
X.data(), X.data() + dimension(X) * num_vectors(X), Y.data()));
381+
for (size_t i = 0; i < Mrows; ++i) {
382+
for (size_t j = 0; j < Ncols; ++j) {
383+
CHECK(X(i, j) == Y(i, j));
384+
}
385+
}
386+
}
387+
388+
{
389+
// We get no data if we load the matrix at an earlier timestamp.
390+
auto Y = tdbPreLoadMatrix<int, stdx::layout_left>(
391+
ctx, tmp_matrix_uri, 0, TemporalPolicy{TimeTravel, 5});
392+
CHECK(num_vectors(Y) == 0);
393+
CHECK(dimension(Y) == 0);
394+
}
395+
396+
{
397+
// We get no data if we load the matrix at an earlier timestamp, even if we
398+
// specify we want to read 4 rows and 2 cols.
399+
auto Y = tdbPreLoadMatrix<int, stdx::layout_left>(
400+
ctx, tmp_matrix_uri, 4, 2, 0, TemporalPolicy{TimeTravel, 5});
401+
CHECK(num_vectors(Y) == 0);
402+
CHECK(dimension(Y) == 0);
403+
CHECK(Y.size() == 0);
404+
}
405+
}

src/include/test/unit_tdb_matrix_with_ids.cc

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,90 @@ TEMPLATE_TEST_CASE(
368368
CHECK(X.ids()[i] == Z.ids()[i]);
369369
}
370370
}
371+
372+
TEST_CASE("tdb_matrix_with_ids: time travel", "[tdb_matrix_with_ids]") {
373+
tiledb::Context ctx;
374+
std::string tmp_matrix_uri =
375+
(std::filesystem::temp_directory_path() / "tmp_tdb_matrix").string();
376+
std::string tmp_ids_uri =
377+
(std::filesystem::temp_directory_path() / "tmp_ids_vector").string();
378+
379+
int offset = 13;
380+
381+
size_t Mrows = 40;
382+
size_t Ncols = 20;
383+
384+
tiledb::VFS vfs(ctx);
385+
if (vfs.is_dir(tmp_matrix_uri)) {
386+
vfs.remove_dir(tmp_matrix_uri);
387+
}
388+
if (vfs.is_dir(tmp_ids_uri)) {
389+
vfs.remove_dir(tmp_ids_uri);
390+
}
391+
392+
auto X = ColMajorMatrixWithIds<float, uint64_t, size_t>(Mrows, Ncols);
393+
fill_and_write_matrix(
394+
ctx,
395+
X,
396+
tmp_matrix_uri,
397+
tmp_ids_uri,
398+
Mrows,
399+
Ncols,
400+
offset,
401+
TemporalPolicy{TimeTravel, 50});
402+
403+
{
404+
// We can load the matrix at the creation timestamp.
405+
auto Y = tdbColMajorPreLoadMatrixWithIds<float, uint64_t, size_t>(
406+
ctx, tmp_matrix_uri, tmp_ids_uri, 0, TemporalPolicy{TimeTravel, 50});
407+
CHECK(num_vectors(Y) == num_vectors(X));
408+
CHECK(dimension(Y) == dimension(X));
409+
CHECK(std::equal(
410+
X.data(), X.data() + dimension(X) * num_vectors(X), Y.data()));
411+
for (size_t i = 0; i < Mrows; ++i) {
412+
for (size_t j = 0; j < Ncols; ++j) {
413+
CHECK(X(i, j) == Y(i, j));
414+
}
415+
}
416+
}
417+
418+
{
419+
// We can load the matrix at a later timestamp.
420+
auto Y = tdbColMajorPreLoadMatrixWithIds<float, uint64_t, size_t>(
421+
ctx, tmp_matrix_uri, tmp_ids_uri, 0, TemporalPolicy{TimeTravel, 100});
422+
CHECK(num_vectors(Y) == num_vectors(X));
423+
CHECK(dimension(Y) == dimension(X));
424+
CHECK(std::equal(
425+
X.data(), X.data() + dimension(X) * num_vectors(X), Y.data()));
426+
for (size_t i = 0; i < Mrows; ++i) {
427+
for (size_t j = 0; j < Ncols; ++j) {
428+
CHECK(X(i, j) == Y(i, j));
429+
}
430+
}
431+
}
432+
433+
{
434+
// We get no data if we load the matrix at an earlier timestamp.
435+
auto Y = tdbColMajorPreLoadMatrixWithIds<float, uint64_t, size_t>(
436+
ctx, tmp_matrix_uri, tmp_ids_uri, 0, TemporalPolicy{TimeTravel, 5});
437+
CHECK(num_vectors(Y) == 0);
438+
CHECK(dimension(Y) == 0);
439+
CHECK(Y.size() == 0);
440+
}
441+
442+
{
443+
// We get no data if we load the matrix at an earlier timestamp, even if we
444+
// specify we want to read 4 rows and 2 cols.
445+
auto Y = tdbColMajorPreLoadMatrixWithIds<float, uint64_t, size_t>(
446+
ctx,
447+
tmp_matrix_uri,
448+
tmp_ids_uri,
449+
4,
450+
2,
451+
0,
452+
TemporalPolicy{TimeTravel, 5});
453+
CHECK(num_vectors(Y) == 0);
454+
CHECK(dimension(Y) == 0);
455+
CHECK(Y.size() == 0);
456+
}
457+
}

0 commit comments

Comments
 (0)