Skip to content

Commit 82cc11e

Browse files
authored
Fix crash in tdb_matrix_with_ids when there was an empty partition (#389)
1 parent c42b9a8 commit 82cc11e

File tree

2 files changed

+133
-94
lines changed

2 files changed

+133
-94
lines changed

src/include/detail/linalg/tdb_partitioned_matrix.h

Lines changed: 59 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ class tdbPartitionedMatrix
279279
, relevant_parts_(relevant_parts)
280280
, squashed_indices_(size(relevant_parts_) + 1)
281281
, last_resident_part_{0} {
282+
scoped_timer _{tdb_func__ + " " + partitioned_vectors_uri_};
282283
if (relevant_parts_.size() >= indices.size()) {
283284
throw std::runtime_error(
284285
"Invalid partitioning, relevant_parts_ size (" +
@@ -287,9 +288,15 @@ class tdbPartitionedMatrix
287288
std::to_string(indices.size()) + ")");
288289
}
289290

290-
total_num_parts_ = size(relevant_parts_);
291+
tiledb_datatype_t attr_type =
292+
partitioned_vectors_schema_.attribute(0).type();
293+
if (attr_type != tiledb::impl::type_to_tiledb<T>::tiledb_type) {
294+
throw std::runtime_error(
295+
"Attribute type mismatch: " + std::to_string(attr_type) + " != " +
296+
std::to_string(tiledb::impl::type_to_tiledb<T>::tiledb_type));
297+
}
291298

292-
scoped_timer _{tdb_func__ + " " + partitioned_vectors_uri_};
299+
total_num_parts_ = size(relevant_parts_);
293300

294301
auto cell_order = partitioned_vectors_schema_.cell_order();
295302
auto tile_order = partitioned_vectors_schema_.tile_order();
@@ -420,47 +427,23 @@ class tdbPartitionedMatrix
420427
" != " + std::to_string(max_resident_parts_ + 1));
421428
}
422429

423-
// The number of resident partitions.
430+
// In a previous load() we may have read in some partitions. Start from
431+
// where we left off:
432+
// - The initial partition number of the resident partitions.
433+
const index_type first_resident_part = last_resident_part_;
434+
// - The initial index numbers of the resident columns.
435+
const index_type first_resident_col = last_resident_col_;
436+
437+
// 1. Calculate the number of resident partitions to load.
424438
size_t num_resident_parts{0};
425-
// The offset of the first partitions in the resident vectors.
426-
// Should be equal to first element of part_view_.
427-
index_type resident_part_offset{0};
428-
// The initial partition number of the resident partitions.
429-
index_type first_resident_part{0};
430-
// The initial index numbers of the resident columns.
431-
index_type first_resident_col{0};
432439
{
433-
const size_t attr_idx = 0;
434-
auto attr = partitioned_vectors_schema_.attribute(attr_idx);
435-
436-
std::string attr_name = attr.name();
437-
tiledb_datatype_t attr_type = attr.type();
438-
if (attr_type != tiledb::impl::type_to_tiledb<T>::tiledb_type) {
439-
throw std::runtime_error(
440-
"Attribute type mismatch: " + std::to_string(attr_type) + " != " +
441-
std::to_string(tiledb::impl::type_to_tiledb<T>::tiledb_type));
442-
}
443-
444-
/*
445-
* Fit as many partitions as we can into column_capacity_
446-
*/
447-
448-
// In a previous load() we may have read in some partitions. Start from
449-
// where we left off.
450-
first_resident_col = last_resident_col_;
451-
first_resident_part = last_resident_part_;
452-
453440
// Now our goal is to calculate the number of columns (i.e. vectors) that
454-
// we can read in, and set num_resident_cols_ to that.
441+
// we can read in, and set num_resident_cols_ to that. We want to fit as
442+
// many partitions as we can into column_capacity_.
455443
last_resident_part_ = first_resident_part;
456444
for (size_t i = first_resident_part; i < total_num_parts_; ++i) {
457445
auto next_part_size = squashed_indices_[i + 1] - squashed_indices_[i];
458446

459-
// Continue if this partition is empty
460-
if (next_part_size == 0) {
461-
continue;
462-
}
463-
464447
if (last_resident_col_ + next_part_size >
465448
first_resident_col + column_capacity_) {
466449
break;
@@ -482,24 +465,23 @@ class tdbPartitionedMatrix
482465

483466
// This is the number of partitions we will read in.
484467
num_resident_parts = last_resident_part_ - first_resident_part;
485-
resident_part_offset = first_resident_part;
486468
if (num_resident_parts > max_resident_parts_) {
487469
throw std::runtime_error(
488470
"Invalid partitioning, num_resident_parts " +
489471
std::to_string(num_resident_parts) + " > " +
490472
std::to_string(max_resident_parts_));
491473
}
492474

475+
if (num_resident_cols_ == 0) {
476+
return false;
477+
}
493478
if ((num_resident_cols_ == 0 && num_resident_parts != 0) ||
494479
(num_resident_cols_ != 0 && num_resident_parts == 0)) {
495480
throw std::runtime_error(
496481
"Invalid partitioning, " + std::to_string(num_resident_cols_) +
497482
" resident cols and " + std::to_string(num_resident_parts) +
498483
" resident parts");
499484
}
500-
if (num_resident_cols_ == 0) {
501-
return false;
502-
}
503485

504486
if (this->part_index_.size() != max_resident_parts_ + 1) {
505487
throw std::runtime_error(
@@ -508,19 +490,24 @@ class tdbPartitionedMatrix
508490
") != max_resident_parts_ + 1 (" +
509491
std::to_string(max_resident_parts_ + 1) + ")");
510492
}
493+
}
511494

512-
/*
513-
* Set up the subarray to read the partitions
514-
*/
495+
// 2. Load the vectors and IDs.
496+
{
497+
// a. Set up the vectors subarray.
498+
auto attr = partitioned_vectors_schema_.attribute(0);
499+
std::string attr_name = attr.name();
515500
tiledb::Subarray subarray(ctx_, *(this->partitioned_vectors_array_));
516-
517501
// For a 128 dimension vector, Dimension 0 will go from 0 to 127.
518502
auto dimension = num_array_rows_;
519503
subarray.add_range(0, 0, (int)dimension - 1);
520504

521-
/**
522-
* Read in the next batch of partitions
523-
*/
505+
// b. Set up the IDs subarray.
506+
auto ids_attr = ids_schema_.attribute(0);
507+
std::string ids_attr_name = ids_attr.name();
508+
tiledb::Subarray ids_subarray(ctx_, *partitioned_ids_array_);
509+
510+
// b. Read in the next batch of partitions
524511
size_t col_count = 0;
525512
for (size_t j = first_resident_part; j < last_resident_part_; ++j) {
526513
size_t start = master_indices_[relevant_parts_[j]];
@@ -531,81 +518,47 @@ class tdbPartitionedMatrix
531518
}
532519
col_count += len;
533520
subarray.add_range(1, (int)start, (int)stop - 1);
521+
ids_subarray.add_range(0, (int)start, (int)stop - 1);
534522
}
535523
if (col_count != last_resident_col_ - first_resident_col) {
536524
throw std::runtime_error("Column count mismatch");
537525
}
538526

539-
auto cell_order = partitioned_vectors_schema_.cell_order();
540-
auto layout_order = cell_order;
541-
527+
// c. Execute the vectors query.
542528
tiledb::Query query(ctx_, *(this->partitioned_vectors_array_));
543-
544529
auto ptr = this->data();
545530
query.set_subarray(subarray)
546-
.set_layout(layout_order)
531+
.set_layout(partitioned_vectors_schema_.cell_order())
547532
.set_data_buffer(attr_name, ptr, col_count * dimension);
548-
// tiledb_helpers::submit_query(tdb_func__, partitioned_vectors_uri_,
549-
// query);
550-
query.submit();
533+
tiledb_helpers::submit_query(tdb_func__, partitioned_vectors_uri_, query);
551534
_memory_data.insert_entry(tdb_func__, col_count * dimension * sizeof(T));
552535

553-
// assert(tiledb::Query::Status::COMPLETE == query.query_dstatus());
554536
auto qs = query.query_status();
555537
// @todo Handle incomplete queries.
556538
if (tiledb::Query::Status::COMPLETE != query.query_status()) {
557539
throw std::runtime_error("Query status is not complete -- fix me");
558540
}
559-
}
560-
561-
// Repeat for ids -- use separate scopes for partitions and ids to keep from
562-
// cross pollinating identifiers
563-
// @todo -- combine these two blocks
564-
{
565-
auto ids_attr_idx = 0;
566-
567-
auto ids_attr = ids_schema_.attribute(ids_attr_idx);
568-
std::string ids_attr_name = ids_attr.name();
569-
570-
tiledb::Subarray ids_subarray(ctx_, *partitioned_ids_array_);
571-
572-
size_t ids_col_count = 0;
573-
for (size_t j = first_resident_part; j < last_resident_part_; ++j) {
574-
size_t start = master_indices_[relevant_parts_[j]];
575-
size_t stop = master_indices_[relevant_parts_[j] + 1];
576-
size_t len = stop - start;
577-
if (len == 0) {
578-
continue;
579-
}
580-
ids_col_count += len;
581-
ids_subarray.add_range(0, (int)start, (int)stop - 1);
582-
}
583-
if (ids_col_count != last_resident_col_ - first_resident_col) {
584-
throw std::runtime_error("Column count mismatch");
585-
}
586541

542+
// d. Execute the IDs query.
587543
tiledb::Query ids_query(ctx_, *partitioned_ids_array_);
588-
589544
auto ids_ptr = this->ids_.data();
590545
ids_query.set_subarray(ids_subarray)
591-
.set_data_buffer(ids_attr_name, ids_ptr, ids_col_count);
546+
.set_data_buffer(ids_attr_name, ids_ptr, col_count);
592547
tiledb_helpers::submit_query(tdb_func__, partitioned_ids_uri_, ids_query);
593-
_memory_data.insert_entry(tdb_func__, ids_col_count * sizeof(T));
548+
_memory_data.insert_entry(tdb_func__, col_count * sizeof(T));
594549

595550
// assert(tiledb::Query::Status::COMPLETE == query.query_status());
596551
if (tiledb::Query::Status::COMPLETE != ids_query.query_status()) {
597552
throw std::runtime_error("Query status is not complete -- fix me");
598553
}
599554
}
600555

601-
/*
602-
* Copy indices for resident partitions into Base::part_index_
603-
* resident_part_offset will be the first index into squashed
604-
* Also [first_resident_part, last_resident_part_)
605-
*/
606-
auto sub = squashed_indices_[resident_part_offset];
556+
// 3. Copy indices for resident partitions into Base::part_index_
557+
// first_resident_part will be the first index into squashed
558+
// Also [first_resident_part, last_resident_part_)
559+
auto sub = squashed_indices_[first_resident_part];
607560
for (size_t i = 0; i < num_resident_parts + 1; ++i) {
608-
this->part_index_[i] = squashed_indices_[i + resident_part_offset] - sub;
561+
this->part_index_[i] = squashed_indices_[i + first_resident_part] - sub;
609562
}
610563

611564
this->num_vectors_ = num_resident_cols_;
@@ -626,6 +579,19 @@ class tdbPartitionedMatrix
626579
partitioned_ids_array_->close();
627580
}
628581
}
582+
583+
void debug_tdb_partitioned_matrix(const std::string& msg, size_t max_size) {
584+
debug_partitioned_matrix(*this, msg, max_size);
585+
debug_vector(master_indices_, "# master_indices_", max_size);
586+
debug_vector(relevant_parts_, "# relevant_parts_", max_size);
587+
debug_vector(squashed_indices_, "# squashed_indices_", max_size);
588+
std::cout << "# total_num_parts_: " << total_num_parts_ << std::endl;
589+
std::cout << "# last_resident_part_: " << last_resident_part_ << std::endl;
590+
std::cout << "# column_capacity_: " << column_capacity_ << std::endl;
591+
std::cout << "# num_resident_cols_: " << num_resident_cols_ << std::endl;
592+
std::cout << "# last_resident_col_: " << last_resident_col_ << std::endl;
593+
std::cout << "# max_resident_parts_: " << max_resident_parts_ << std::endl;
594+
}
629595
};
630596

631597
/**

src/include/test/unit_tdb_partitioned_matrix.cc

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "cpos.h"
3636
#include "detail/linalg/matrix.h"
3737
#include "detail/linalg/tdb_io.h"
38+
#include "detail/linalg/tdb_matrix_with_ids.h"
3839
#include "detail/linalg/tdb_partitioned_matrix.h"
3940
#include "mdspan/mdspan.hpp"
4041

@@ -85,7 +86,6 @@ std::vector<std::vector<T>> generateSubsets(int num_parts) {
8586
}
8687

8788
TEST_CASE("can load correctly", "[tdb_partitioned_matrix]") {
88-
return;
8989
tiledb::Context ctx;
9090
tiledb::VFS vfs(ctx);
9191

@@ -316,3 +316,76 @@ TEST_CASE("test different combinations", "[tdb_partitioned_matrix]") {
316316
}
317317
}
318318
}
319+
320+
TEST_CASE(
321+
"tdb_partitioned_matrix: empty partition", "[tdb_partitioned_matrix]") {
322+
tiledb::Context ctx;
323+
tiledb::VFS vfs(ctx);
324+
325+
using feature_type = uint64_t;
326+
using id_type = uint64_t;
327+
using part_index_type = uint64_t;
328+
329+
std::string partitioned_vectors_uri =
330+
(std::filesystem::temp_directory_path() / "partitioned_vectors").string();
331+
std::string ids_uri =
332+
(std::filesystem::temp_directory_path() / "ids").string();
333+
334+
size_t num_vectors = 10000;
335+
size_t dimensions = 128;
336+
337+
// Setup data.
338+
{
339+
if (vfs.is_dir(partitioned_vectors_uri)) {
340+
vfs.remove_dir(partitioned_vectors_uri);
341+
}
342+
if (vfs.is_dir(ids_uri)) {
343+
vfs.remove_dir(ids_uri);
344+
}
345+
346+
auto partitioned_vectors =
347+
ColMajorMatrix<feature_type>(dimensions, num_vectors);
348+
for (size_t i = 0; i < dimensions; ++i) {
349+
for (size_t j = 0; j < num_vectors; ++j) {
350+
partitioned_vectors(i, j) = j;
351+
}
352+
}
353+
write_matrix(ctx, partitioned_vectors, partitioned_vectors_uri);
354+
std::vector<id_type> ids(num_vectors, 0);
355+
for (size_t i = 0; i < num_vectors; ++i) {
356+
ids[i] = i;
357+
}
358+
write_vector(ctx, ids, ids_uri);
359+
}
360+
361+
// Test that we do not crash if we have an empty part (i.e. two elements in
362+
// indices with the same value). These values were taken from running
363+
// `api_ivf_flat_index: read index and query infinite and finite - finite out
364+
// of core, 1000, nprobe: 32, max_iter: 8` which used to crash with these
365+
// values.
366+
std::vector<part_index_type> relevant_parts = {
367+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
368+
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34,
369+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
370+
51, 52, 53, 55, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 68,
371+
69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
372+
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99};
373+
std::vector<part_index_type> indices = {
374+
0, 1, 116, 215, 318, 418, 600, 662, 862, 1041, 1176, 1248,
375+
1349, 1488, 1612, 1754, 1877, 1878, 1880, 2028, 2135, 2228, 2328, 2330,
376+
2464, 2526, 2682, 2785, 2911, 3059, 3191, 3192, 3266, 3395, 3516, 3607,
377+
3757, 3758, 3861, 3998, 4100, 4306, 4446, 4618, 4733, 4838, 4958, 5112,
378+
5169, 5277, 5372, 5466, 5653, 5729, 5810, 5811, 5977, 6056, 6057, 6266,
379+
6269, 6337, 6338, 6338, 6437, 6570, 6660, 6727, 6820, 6900, 7004, 7138,
380+
7139, 7220, 7227, 7339, 7414, 7539, 7695, 7781, 8004, 8095, 8161, 8235,
381+
8320, 8389, 8495, 8619, 8769, 8840, 9043, 9088, 9183, 9241, 9293, 9425,
382+
9548, 9625, 9743, 9880, 10000};
383+
auto matrix =
384+
tdbColMajorPartitionedMatrix<feature_type, id_type, part_index_type>(
385+
ctx, partitioned_vectors_uri, indices, ids_uri, relevant_parts, 1000);
386+
while (matrix.load()) {
387+
CHECK(matrix.num_vectors() > 0);
388+
CHECK(matrix.num_partitions() > 0);
389+
CHECK(_cpo::dimensions(matrix) == dimensions);
390+
}
391+
}

0 commit comments

Comments
 (0)