Skip to content

Commit fac90e4

Browse files
authored
Save kmeans settings to IVF PQ metadata (#452)
1 parent e7ae919 commit fac90e4

File tree

10 files changed

+229
-103
lines changed

10 files changed

+229
-103
lines changed

src/include/api/ivf_pq_index.h

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ class IndexIVFPQ {
106106
max_iterations_ = std::stol(value);
107107
} else if (key == "convergence_tolerance") {
108108
convergence_tolerance_ = std::stof(value);
109+
} else if (key == "reassign_ratio") {
110+
reassign_ratio_ = std::stof(value);
109111
} else if (key == "feature_type") {
110112
feature_datatype_ = string_to_datatype(value);
111113
} else if (key == "id_type") {
@@ -150,6 +152,7 @@ class IndexIVFPQ {
150152
num_subspaces_ = index_->num_subspaces();
151153
max_iterations_ = index_->max_iterations();
152154
convergence_tolerance_ = index_->convergence_tolerance();
155+
reassign_ratio_ = index_->reassign_ratio();
153156

154157
if (dimensions_ != 0 && dimensions_ != index_->dimensions()) {
155158
throw std::runtime_error(
@@ -190,6 +193,7 @@ class IndexIVFPQ {
190193
num_subspaces_,
191194
max_iterations_,
192195
convergence_tolerance_,
196+
reassign_ratio_,
193197
index_ ? std::make_optional<TemporalPolicy>(index_->temporal_policy()) :
194198
std::nullopt);
195199

@@ -295,6 +299,10 @@ class IndexIVFPQ {
295299
return convergence_tolerance_;
296300
}
297301

302+
constexpr auto reassign_ratio() const {
303+
return reassign_ratio_;
304+
}
305+
298306
constexpr auto feature_type() const {
299307
return feature_datatype_;
300308
}
@@ -380,6 +388,7 @@ class IndexIVFPQ {
380388
[[nodiscard]] virtual uint64_t num_subspaces() const = 0;
381389
[[nodiscard]] virtual uint64_t max_iterations() const = 0;
382390
[[nodiscard]] virtual float convergence_tolerance() const = 0;
391+
[[nodiscard]] virtual float reassign_ratio() const = 0;
383392
};
384393

385394
/**
@@ -396,13 +405,15 @@ class IndexIVFPQ {
396405
size_t n_list,
397406
size_t num_subspaces,
398407
size_t max_iterations,
399-
size_t convergence_tolerance,
408+
float convergence_tolerance,
409+
float reassign_ratio,
400410
std::optional<TemporalPolicy> temporal_policy)
401411
: impl_index_(
402412
n_list,
403413
num_subspaces,
404414
max_iterations,
405415
convergence_tolerance,
416+
reassign_ratio,
406417
temporal_policy) {
407418
}
408419

@@ -532,6 +543,10 @@ class IndexIVFPQ {
532543
return impl_index_.convergence_tolerance();
533544
}
534545

546+
float reassign_ratio() const override {
547+
return impl_index_.reassign_ratio();
548+
}
549+
535550
private:
536551
/**
537552
* @brief Instance of the concrete class.
@@ -540,7 +555,7 @@ class IndexIVFPQ {
540555
};
541556

542557
// clang-format off
543-
using constructor_function = std::function<std::unique_ptr<index_base>(size_t, size_t, size_t, float, std::optional<TemporalPolicy>)>;
558+
using constructor_function = std::function<std::unique_ptr<index_base>(size_t, size_t, size_t, float, float, std::optional<TemporalPolicy>)>;
544559
using table_type = std::map<std::tuple<tiledb_datatype_t, tiledb_datatype_t, tiledb_datatype_t>, constructor_function>;
545560
static const table_type dispatch_table;
546561

@@ -558,6 +573,7 @@ class IndexIVFPQ {
558573
size_t num_subspaces_{16};
559574
size_t max_iterations_{2};
560575
float convergence_tolerance_{0.000025f};
576+
float reassign_ratio_{0.075f};
561577
tiledb_datatype_t feature_datatype_{TILEDB_ANY};
562578
tiledb_datatype_t id_datatype_{TILEDB_ANY};
563579
tiledb_datatype_t partitioning_index_datatype_{TILEDB_ANY};
@@ -566,18 +582,18 @@ class IndexIVFPQ {
566582

567583
// clang-format off
568584
const IndexIVFPQ::table_type IndexIVFPQ::dispatch_table = {
569-
{{TILEDB_INT8, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
570-
{{TILEDB_UINT8, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
571-
{{TILEDB_FLOAT32, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
572-
{{TILEDB_INT8, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
573-
{{TILEDB_UINT8, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
574-
{{TILEDB_FLOAT32, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
575-
{{TILEDB_INT8, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
576-
{{TILEDB_UINT8, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
577-
{{TILEDB_FLOAT32, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
578-
{{TILEDB_INT8, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
579-
{{TILEDB_UINT8, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
580-
{{TILEDB_FLOAT32, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, temporal_policy); }},
585+
{{TILEDB_INT8, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
586+
{{TILEDB_UINT8, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
587+
{{TILEDB_FLOAT32, TILEDB_UINT32, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint32_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
588+
{{TILEDB_INT8, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
589+
{{TILEDB_UINT8, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
590+
{{TILEDB_FLOAT32, TILEDB_UINT32, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint32_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
591+
{{TILEDB_INT8, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
592+
{{TILEDB_UINT8, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
593+
{{TILEDB_FLOAT32, TILEDB_UINT64, TILEDB_UINT32}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint64_t, uint32_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
594+
{{TILEDB_INT8, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<int8_t, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
595+
{{TILEDB_UINT8, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<uint8_t, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
596+
{{TILEDB_FLOAT32, TILEDB_UINT64, TILEDB_UINT64}, [](size_t nlist, size_t num_subspaces, size_t max_iterations, float convergence_tolerance, float reassign_ratio, std::optional<TemporalPolicy> temporal_policy) { return std::make_unique<index_impl<ivf_pq_index<float, uint64_t, uint64_t>>>(nlist, num_subspaces, max_iterations, convergence_tolerance, reassign_ratio, temporal_policy); }},
581597
};
582598

583599
const IndexIVFPQ::uri_table_type IndexIVFPQ::uri_dispatch_table = {

src/include/index/ivf_pq_group.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class ivf_pq_group : public base_index_group<index_type> {
8989
size_t num_subspaces = 0)
9090
: Base(ctx, uri, rw, temporal_policy, version, dimensions) {
9191
if (rw == TILEDB_WRITE && !this->exists()) {
92+
// num_clusters and num_subspaces must be set before we call
93+
// create_default_impl().
9294
if (num_clusters == 0) {
9395
throw std::invalid_argument(
9496
"num_clusters must be specified when creating a new group.");
@@ -221,8 +223,7 @@ class ivf_pq_group : public base_index_group<index_type> {
221223
}
222224

223225
/*****************************************************************************
224-
* Getters and setters for PQ related metadata: num_subspaces, sub_dimension,
225-
* bits_per_subspace, num_clusters
226+
* Getters and setters for PQ related metadata
226227
****************************************************************************/
227228
auto get_num_subspaces() const {
228229
return metadata_.num_subspaces_;
@@ -252,6 +253,27 @@ class ivf_pq_group : public base_index_group<index_type> {
252253
metadata_.num_clusters_ = num_clusters;
253254
}
254255

256+
auto get_max_iterations() const {
257+
return metadata_.max_iterations_;
258+
}
259+
auto set_max_iterations(uint64_t max_iterations) {
260+
metadata_.max_iterations_ = max_iterations;
261+
}
262+
263+
auto get_convergence_tolerance() const {
264+
return metadata_.convergence_tolerance_;
265+
}
266+
auto set_convergence_tolerance(float convergence_tolerance) {
267+
metadata_.convergence_tolerance_ = convergence_tolerance;
268+
}
269+
270+
auto get_reassign_ratio() const {
271+
return metadata_.reassign_ratio_;
272+
}
273+
auto set_reassign_ratio(float reassign_ratio) {
274+
metadata_.reassign_ratio_ = reassign_ratio;
275+
}
276+
255277
/*****************************************************************************
256278
* Create a ready-to-use group with default arrays
257279
****************************************************************************/

0 commit comments

Comments
 (0)