|
3 | 3 | import numpy as np |
4 | 4 |
|
5 | 5 | import tiledb |
| 6 | +from tiledb.cloud.dag import Mode |
6 | 7 | from tiledb.vector_search.embeddings import ObjectEmbedding |
7 | 8 | from tiledb.vector_search.object_api import object_index |
8 | 9 | from tiledb.vector_search.object_readers import ObjectPartition |
@@ -142,9 +143,9 @@ def read_objects_by_external_ids(self, ids: List[int]) -> OrderedDict: |
142 | 143 | return {"object": objects, "external_id": external_ids} |
143 | 144 |
|
144 | 145 |
|
145 | | -def evaluate_query(index_uri, query_kwargs, dim_id, vector_dim_offset): |
| 146 | +def evaluate_query(index_uri, query_kwargs, dim_id, vector_dim_offset, config=None): |
146 | 147 | v_id = dim_id - vector_dim_offset |
147 | | - index = object_index.ObjectIndex(uri=index_uri) |
| 148 | + index = object_index.ObjectIndex(uri=index_uri, config=config) |
148 | 149 | distances, objects, metadata = index.query( |
149 | 150 | {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs |
150 | 151 | ) |
@@ -188,7 +189,9 @@ def df_filter(row): |
188 | 189 | object_ids, np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4]) |
189 | 190 | ) |
190 | 191 |
|
191 | | - index = object_index.ObjectIndex(uri=index_uri, load_metadata_in_memory=False) |
| 192 | + index = object_index.ObjectIndex( |
| 193 | + uri=index_uri, load_metadata_in_memory=False, config=config |
| 194 | + ) |
192 | 195 | distances, objects, metadata = index.query( |
193 | 196 | {"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs |
194 | 197 | ) |
@@ -296,6 +299,135 @@ def test_object_index_ivf_flat(tmp_path): |
296 | 299 | ) |
297 | 300 |
|
298 | 301 |
|
def test_object_index_ivf_flat_cloud(tmp_path):
    """Exercise an IVF_FLAT object index stored at a cloud URI using BATCH mode.

    The test creates the index, then runs four update-and-query rounds:
    initial ingestion, re-ingestion of the same objects, ingestion of a new
    object id range, and re-ingestion of that range with a different vector
    offset. Each round is validated with ``evaluate_query``.

    The index URI is deleted in a ``finally`` clause so that a failing round
    does not leave the remote index behind.

    ``tmp_path`` is accepted for signature parity with the other tests in
    this module; it is not used here because the index lives at a cloud URI.
    """
    from common import create_cloud_uri
    from common import delete_uri
    from common import setUpCloudToken

    setUpCloudToken()
    config = tiledb.cloud.Config().dict()
    index_uri = create_cloud_uri("object_index_ivf_flat")
    worker_resources = {"cpu": "1", "memory": "2Gi"}

    # Every update round uses exactly the same execution settings, so they
    # are defined once instead of being repeated at each call site.
    update_kwargs = {
        "embeddings_generation_driver_mode": Mode.BATCH,
        "embeddings_generation_mode": Mode.BATCH,
        "vector_indexing_mode": Mode.BATCH,
        "workers": 2,
        "worker_resources": worker_resources,
        "driver_resources": worker_resources,
        "kmeans_resources": worker_resources,
        "ingest_resources": worker_resources,
        "consolidate_partition_resources": worker_resources,
        "objects_per_partition": 500,
        "partitions": 10,
        "config": config,
    }

    def update_and_check(index, dim_id, vector_dim_offset):
        # Re-ingest with the shared settings, then validate query results.
        index.update_index(**update_kwargs)
        evaluate_query(
            index_uri=index_uri,
            query_kwargs={"nprobe": 10},
            dim_id=dim_id,
            vector_dim_offset=vector_dim_offset,
            config=config,
        )

    try:
        reader = TestReader(
            object_id_start=0,
            object_id_end=1000,
            vector_dim_offset=0,
        )
        embedding = TestEmbedding()

        index = object_index.create(
            uri=index_uri,
            index_type="IVF_FLAT",
            object_reader=reader,
            embedding=embedding,
            config=config,
        )

        # Check initial ingestion
        update_and_check(index, dim_id=42, vector_dim_offset=0)

        # Check that updating the same data doesn't create duplicates
        update_and_check(index, dim_id=42, vector_dim_offset=0)

        # Add new data with a new reader
        reader = TestReader(
            object_id_start=1000,
            object_id_end=2000,
            vector_dim_offset=0,
        )
        index.update_object_reader(reader, config=config)
        update_and_check(index, dim_id=1042, vector_dim_offset=0)

        # Check overwriting existing data
        reader = TestReader(
            object_id_start=1000,
            object_id_end=2000,
            vector_dim_offset=1000,
        )
        index.update_object_reader(reader, config=config)
        update_and_check(index, dim_id=2042, vector_dim_offset=1000)
    finally:
        # Always remove the remote index, even when a round above fails.
        delete_uri(index_uri, config)
| 429 | + |
| 430 | + |
299 | 431 | def test_object_index_flat(tmp_path): |
300 | 432 | reader = TestReader( |
301 | 433 | object_id_start=0, |
|
0 commit comments