Commit 62ff91a

Allow setting IVF PQ partitions when re-ingesting, fix IVF PQ object index tests (#453)
1 parent fac90e4 commit 62ff91a

8 files changed, +121 -46 lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 3 additions & 0 deletions
@@ -450,6 +450,9 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
         are added to the index. It triggers a base index re-indexing, merging the non-consolidated
         updates and the rest of the base vectors.
 
+        TODO(sc-51202): This throws with a unintuitive error message if update()/delete()/etc. has
+        not been called.
+
         Parameters
         ----------
         retrain_index: bool
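
A minimal sketch of the call order this new docstring note assumes: at least one update()/delete() is applied before consolidate_updates(). The URI, the example vectors, and the exact ingest()/update()/delete() keyword names are illustrative assumptions, not taken from this diff.

import numpy as np
from tiledb.vector_search import ingestion

index_uri = "/tmp/example_ivf_pq_index"  # hypothetical URI
vectors = np.random.rand(1000, 4).astype(np.float32)

# Initial ingestion; parameters other than index_type/index_uri are assumed.
index = ingestion.ingest(index_type="IVF_PQ", index_uri=index_uri, input_vectors=vectors)

# Apply at least one update before consolidating; per the note above (sc-51202),
# calling consolidate_updates() with nothing to consolidate raises an unintuitive error.
index.update(vector=vectors[0], external_id=0)
index.delete(external_id=1)
index = index.consolidate_updates(retrain_index=True)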

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 7 additions & 4 deletions
@@ -1557,6 +1557,7 @@ def ingest_type_erased(
     dimensions: int,
     size: int,
     batch: int,
+    partitions: int,
     config: Optional[Mapping[str, Any]] = None,
     verbose: bool = False,
     trace_id: Optional[str] = None,
@@ -1671,16 +1672,17 @@ def ingest_type_erased(
     from tiledb.vector_search import _tiledbvspy as vspy
 
     ctx = vspy.Ctx(config)
+    data = vspy.FeatureVectorArray(
+        ctx, parts_array_uri, ids_array_uri, 0, to_temporal_policy(index_timestamp)
+    )
     if index_type == "VAMANA":
         index = vspy.IndexVamana(ctx, index_group_uri)
+        index.train(data)
     elif index_type == "IVF_PQ":
         index = vspy.IndexIVFPQ(ctx, index_group_uri)
+        index.train(data, partitions)
     else:
         raise ValueError(f"Unsupported index type: {index_type}")
-    data = vspy.FeatureVectorArray(
-        ctx, parts_array_uri, ids_array_uri, 0, to_temporal_policy(index_timestamp)
-    )
-    index.train(data)
     index.add(data)
     index.write_index(ctx, index_group_uri, to_temporal_policy(index_timestamp))
 
@@ -2270,6 +2272,7 @@ def scale_resources(min_resource, max_resource, max_input_size, input_size):
         dimensions=dimensions,
         size=size,
         batch=input_vectors_batch_size,
+        partitions=partitions,
         config=config,
         verbose=verbose,
         trace_id=trace_id,
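
With this change the ingestion path forwards a partitions value into ingest_type_erased(), and the IVF_PQ branch passes it to train() instead of the previous unconditional index.train(data). A hedged sketch of how a caller might set the partition count at ingestion time; the URI and the exact set of ingest() keywords beyond those visible in this commit are assumptions.

import numpy as np
from tiledb.vector_search import ingestion

vectors = np.random.rand(10_000, 4).astype(np.float32)
index = ingestion.ingest(
    index_type="IVF_PQ",
    index_uri="/tmp/ivf_pq_index",  # hypothetical URI
    input_vectors=vectors,
    partitions=16,  # forwarded to ingest_type_erased() and then to IndexIVFPQ.train()
)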

apis/python/src/tiledb/vector_search/object_api/object_index.py

Lines changed: 3 additions & 3 deletions
@@ -125,9 +125,9 @@ def query(
         self,
         query_objects: np.ndarray,
         k: int,
-        query_metadata: OrderedDict = None,
-        metadata_array_cond: str = None,
-        metadata_df_filter_fn: str = None,
+        query_metadata: Optional[OrderedDict] = None,
+        metadata_array_cond: Optional[str] = None,
+        metadata_df_filter_fn: Optional[str] = None,
         return_objects: bool = True,
         return_metadata: bool = True,
         **kwargs,
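
The change here replaces implicit-Optional defaults with explicit Optional[...] annotations, which is what PEP 484 and strict type checkers expect when a parameter defaults to None. A tiny illustration with hypothetical function names:

from typing import Optional

def query_old(metadata_array_cond: str = None): ...            # implicit Optional; flagged by strict mypy
def query_new(metadata_array_cond: Optional[str] = None): ...  # explicit, matching the change above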

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 5 additions & 4 deletions
@@ -480,10 +480,11 @@ void init_type_erased_module(py::module_& m) {
           })
       .def(
           "train",
-          [](IndexIVFPQ& index, const FeatureVectorArray& vectors) {
-            index.train(vectors);
-          },
-          py::arg("vectors"))
+          [](IndexIVFPQ& index,
+             const FeatureVectorArray& vectors,
+             std::optional<size_t> nlist) { index.train(vectors, nlist); },
+          py::arg("vectors"),
+          py::arg("nlist") = std::nullopt)
       .def(
           "add",
           [](IndexIVFPQ& index, const FeatureVectorArray& vectors) {
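
Because nlist defaults to std::nullopt, the Python-visible IndexIVFPQ.train() keeps its old one-argument form while also accepting the partition count that ingest_type_erased() now passes. A hedged sketch of the two call shapes; the group and array URIs are placeholders, and the FeatureVectorArray constructor arguments simply mirror the diff above (the trailing 0/None arguments are assumptions).

from tiledb.vector_search import _tiledbvspy as vspy

ctx = vspy.Ctx({})
index = vspy.IndexIVFPQ(ctx, "/tmp/ivf_pq_group")  # hypothetical group URI
data = vspy.FeatureVectorArray(ctx, "/tmp/parts_array", "/tmp/ids_array", 0, None)  # hypothetical URIs

index.train(data)             # nlist omitted: keep the partition count stored in the index
index.train(data, nlist=128)  # override it, as the IVF_PQ branch of ingest_type_erased() now does
index.add(data)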

apis/python/test/test_object_index.py

Lines changed: 92 additions & 30 deletions
@@ -144,25 +144,64 @@ def read_objects_by_external_ids(self, ids: List[int]) -> OrderedDict:
         return {"object": objects, "external_id": external_ids}
 
 
-def evaluate_query(index_uri, query_kwargs, dim_id, vector_dim_offset, config=None):
+def assert_equal(
+    index_type: str,
+    ids: np.array,
+    expected_ids: np.array,
+    ivf_pq_accuracy_threshold: float,
+):
+    """
+    IVF_PQ index has a lower recall rate than other indexes b/c of PQ-encoding, so we need to lower
+    the threshold.
+
+    Parameters
+    ----------
+    index_type: str
+        The index type.
+    ids: np.array
+        The ids returned by the query.
+    expected_ids: np.array
+        The expected ids.
+    ivf_pq_accuracy_threshold: float
+        The minimum fraction of expected_ids that must be in ids.
+    """
+    assert len(ids) == len(expected_ids)
+    if index_type == "IVF_PQ":
+        matches = np.intersect1d(ids, expected_ids)
+        assert len(matches) / len(ids) >= ivf_pq_accuracy_threshold
+        return
+
+    assert np.array_equiv(ids, expected_ids)
+
+
+def evaluate_query(
+    index_type: str, index_uri, query_kwargs, dim_id, vector_dim_offset, config=None
+):
     v_id = dim_id - vector_dim_offset
+
     index = object_index.ObjectIndex(uri=index_uri, config=config)
     distances, objects, metadata = index.query(
-        {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs
+        {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=21, **query_kwargs
     )
-    assert np.array_equiv(
+    assert_equal(
+        index_type,
         np.unique(objects["external_id"]),
-        np.array([v_id - 2, v_id - 1, v_id, v_id + 1, v_id + 2]),
+        np.array([v_id + i for i in range(-10, 11)]),
+        ivf_pq_accuracy_threshold=0.8,
     )
+
     distances, object_ids = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
-        k=5,
+        k=21,
         return_objects=False,
         return_metadata=False,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        np.unique(object_ids), np.array([v_id - 2, v_id - 1, v_id, v_id + 1, v_id + 2])
+    assert_equal(
+        index_type,
+        np.unique(object_ids),
+        np.array([v_id + i for i in range(-10, 11)]),
+        ivf_pq_accuracy_threshold=0.8,
     )
 
     def df_filter(row):
@@ -171,66 +210,84 @@ def df_filter(row):
     distances, objects, metadata = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
         metadata_df_filter_fn=df_filter,
-        k=5,
+        k=21,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        objects["external_id"], np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4])
+    assert_equal(
+        index_type,
+        np.unique(objects["external_id"]),
+        np.array([v_id + i for i in range(0, 21)]),
+        ivf_pq_accuracy_threshold=0.8,
     )
 
     distances, object_ids = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
         metadata_df_filter_fn=df_filter,
-        k=5,
+        k=21,
         return_objects=False,
         return_metadata=False,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        object_ids, np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4])
+    assert_equal(
+        index_type,
+        np.unique(object_ids),
+        np.array([v_id + i for i in range(0, 21)]),
+        ivf_pq_accuracy_threshold=0.8,
    )
 
     index = object_index.ObjectIndex(
         uri=index_uri, load_metadata_in_memory=False, config=config
     )
     distances, objects, metadata = index.query(
-        {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs
+        {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=21, **query_kwargs
     )
-    assert np.array_equiv(
+    assert_equal(
+        index_type,
         np.unique(objects["external_id"]),
-        np.array([v_id - 2, v_id - 1, v_id, v_id + 1, v_id + 2]),
+        np.array([v_id + i for i in range(-10, 11)]),
+        ivf_pq_accuracy_threshold=0.8,
     )
+
     distances, object_ids = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
-        k=5,
+        k=21,
         return_objects=False,
         return_metadata=False,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        np.unique(object_ids), np.array([v_id - 2, v_id - 1, v_id, v_id + 1, v_id + 2])
+    assert_equal(
+        index_type,
+        np.unique(object_ids),
+        np.array([v_id + i for i in range(-10, 11)]),
+        ivf_pq_accuracy_threshold=0.8,
    )
 
     distances, objects, metadata = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
         metadata_array_cond=f"test_attr >= {dim_id}",
-        k=5,
+        k=21,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        objects["external_id"], np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4])
+    assert_equal(
+        index_type,
+        np.unique(objects["external_id"]),
+        np.array([v_id + i for i in range(0, 21)]),
+        ivf_pq_accuracy_threshold=0.8,
    )
 
     distances, object_ids = index.query(
         {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])},
         metadata_array_cond=f"test_attr >= {dim_id}",
-        k=5,
+        k=21,
         return_objects=False,
         return_metadata=False,
         **query_kwargs,
     )
-    assert np.array_equiv(
-        object_ids, np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4])
+    assert_equal(
+        index_type,
+        np.unique(object_ids),
+        np.array([v_id + i for i in range(0, 21)]),
+        ivf_pq_accuracy_threshold=0.8,
    )
 
 
@@ -256,12 +313,8 @@ def test_object_index(tmp_path):
 
         # Check initial ingestion
         index.update_index(partitions=10)
-
-        # TODO(SC-48908): Fix IVF_PQ with object index queries and remove.
-        if index_type == "IVF_PQ":
-            continue
-
         evaluate_query(
+            index_type=index_type,
             index_uri=index_uri,
             query_kwargs={"nprobe": 10, "l_search": 250},
             dim_id=42,
@@ -272,6 +325,7 @@ def test_object_index(tmp_path):
         index = object_index.ObjectIndex(uri=index_uri)
         index.update_index(partitions=10)
         evaluate_query(
+            index_type=index_type,
             index_uri=index_uri,
             query_kwargs={"nprobe": 10, "l_search": 500},
             dim_id=42,
@@ -288,6 +342,7 @@ def test_object_index(tmp_path):
         index.update_object_reader(reader)
         index.update_index(partitions=10)
         evaluate_query(
+            index_type=index_type,
            index_uri=index_uri,
            query_kwargs={"nprobe": 10, "l_search": 500},
            dim_id=1042,
@@ -304,6 +359,7 @@ def test_object_index(tmp_path):
         index.update_object_reader(reader)
         index.update_index(partitions=10)
         evaluate_query(
+            index_type=index_type,
            index_uri=index_uri,
            query_kwargs={"nprobe": 10, "l_search": 500},
            dim_id=2042,
@@ -351,6 +407,7 @@ def test_object_index_ivf_flat_cloud(tmp_path):
         config=config,
     )
     evaluate_query(
+        index_type="IVF_FLAT",
         index_uri=index_uri,
         query_kwargs={"nprobe": 10},
         dim_id=42,
@@ -381,6 +438,7 @@ def test_object_index_ivf_flat_cloud(tmp_path):
         config=config,
     )
     evaluate_query(
+        index_type="IVF_FLAT",
         index_uri=index_uri,
         query_kwargs={"nprobe": 10},
         dim_id=1042,
@@ -409,6 +467,7 @@ def test_object_index_flat(tmp_path):
     # Check initial ingestion
     index.update_index()
     evaluate_query(
+        index_type="FLAT",
         index_uri=index_uri,
         query_kwargs={},
         dim_id=42,
@@ -419,6 +478,7 @@ def test_object_index_flat(tmp_path):
     index = object_index.ObjectIndex(uri=index_uri)
     index.update_index()
     evaluate_query(
+        index_type="FLAT",
         index_uri=index_uri,
         query_kwargs={},
         dim_id=42,
@@ -435,6 +495,7 @@ def test_object_index_flat(tmp_path):
     index.update_object_reader(reader)
     index.update_index()
     evaluate_query(
+        index_type="FLAT",
         index_uri=index_uri,
         query_kwargs={},
         dim_id=1042,
@@ -451,6 +512,7 @@ def test_object_index_flat(tmp_path):
     index.update_object_reader(reader)
     index.update_index()
     evaluate_query(
+        index_type="FLAT",
         index_uri=index_uri,
         query_kwargs={},
         dim_id=2042,
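
assert_equal() relaxes the exact-match assertions into a recall check for IVF_PQ, since product quantization is lossy and the top-k neighbors are only approximately recovered. A self-contained numpy sketch of the same intersection-based check (the example values are made up):

import numpy as np

def recall(ids: np.ndarray, expected_ids: np.ndarray) -> float:
    # Order-insensitive fraction of expected ids actually returned, the same
    # np.intersect1d-based measure assert_equal applies when index_type == "IVF_PQ".
    return len(np.intersect1d(ids, expected_ids)) / len(expected_ids)

expected = np.arange(32, 53)                                    # v_id - 10 .. v_id + 10 for v_id = 42
returned = np.concatenate([np.arange(32, 50), [99, 100, 101]])  # 18 of the 21 expected ids
assert len(returned) == len(expected)
assert recall(returned, expected) >= 0.8                        # 18 / 21 is about 0.86, above the 0.8 threshold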

src/include/api/ivf_pq_index.h

Lines changed: 10 additions & 3 deletions
@@ -164,11 +164,14 @@ class IndexIVFPQ {
 
   /**
    * @brief Train the index based on the given training set.
-   * @param training_set
-   * @param init
+   * @param training_set The training input vectors.
+   * @param n_list The number of clusters to use in the index. Can be passed to
+   * override the value we used when we first created the index.
    */
   // @todo -- infer feature type from input
-  void train(const FeatureVectorArray& training_set) {
+  void train(
+      const FeatureVectorArray& training_set,
+      std::optional<size_t> n_list = std::nullopt) {
     if (feature_datatype_ == TILEDB_ANY) {
       feature_datatype_ = training_set.feature_type();
     } else if (feature_datatype_ != training_set.feature_type()) {
@@ -184,6 +187,10 @@ class IndexIVFPQ {
       throw std::runtime_error("Unsupported datatype combination");
     }
 
+    if (n_list.has_value()) {
+      n_list_ = *n_list;
+    }
+
     // Create a new index. Note that we may have already loaded an existing
     // index by URI. In that case, we have updated our local state (i.e.
     // num_subspaces_, etc.), but we should also use the timestamp from that
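
The new optional n_list parameter means a retrain can override the partition count the index was created with, which is what lets partitions take effect on re-ingestion (the tests above drive it through update_index(partitions=10)). A toy, self-contained Python sketch of just the override semantics, not the library API:

from typing import Optional

class PartitionedIndexSketch:
    """Toy model of IndexIVFPQ::train()'s n_list handling: keep the stored
    partition count unless train() is explicitly given a new one."""

    def __init__(self, partitions: int):
        self.partitions = partitions

    def train(self, n_list: Optional[int] = None) -> None:
        if n_list is not None:
            self.partitions = n_list  # mirrors `if (n_list.has_value()) n_list_ = *n_list;`

idx = PartitionedIndexSketch(partitions=4)
idx.train()           # partition count stays 4
idx.train(n_list=16)  # re-ingestion with an override
assert idx.partitions == 16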

src/include/api/vamana_index.h

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@
  *
  * We support all combinations of the following types for feature, id, and px
  * datatypes:
- * - feature_type: uint8 or float
+ * - feature_type: uint8, int8, or float
  * - id_type: uint32 or uint64
  * - adjacency_row_index_type: uint32 or uint64
  */
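
The comment now lists int8 among the supported Vamana feature types. A hedged sketch of ingesting int8 vectors; whether the Python ingest() path accepts int8 input, and its exact keywords, are assumptions not established by this diff.

import numpy as np
from tiledb.vector_search import ingestion

vectors = np.random.randint(-128, 128, size=(1000, 4), dtype=np.int8)
index = ingestion.ingest(
    index_type="VAMANA",
    index_uri="/tmp/vamana_int8_index",  # hypothetical URI
    input_vectors=vectors,
)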

src/include/index/ivf_pq_index.h

Lines changed: 0 additions & 1 deletion
@@ -713,7 +713,6 @@ class ivf_pq_index {
    * @param training_set_ids IDs for each vector.
    *
    * @todo Create and write index that is larger than RAM
-   * @todo Use training_set_ids as the external IDs.
    */
   template <
       feature_vector_array Array,
