Skip to content

Commit 589958e

Browse files
Fix BATCH execution error (#351)
Added tests for BATCH execution in cloud This was causing a pickling error during BATCH execution. ``` Traceback (most recent call last): File "/opt/conda/lib/python3.9/site-packages/tdbudf/batch_udf_main.py", line 339, in real_main result = udf(*args, **kwargs) File "/Users/npapa/miniforge3/envs/tiledb_vs_8_arm/lib/python3.9/site-packages/tiledb/vector_search/object_api/embeddings_ingestion.py", line 432, in ingest_embeddings File "/opt/conda/lib/python3.9/site-packages/tiledb/cloud/dag/dag.py", line 1162, in compute self._batch_taskgraph = self._build_batch_taskgraph() File "/opt/conda/lib/python3.9/site-packages/tiledb/cloud/dag/dag.py", line 1534, in _build_batch_taskgraph kwargs["executable_code"] = codecs.PickleCodec.encode_base64(func) File "/opt/conda/lib/python3.9/site-packages/tiledb/cloud/_results/codecs.py", line 54, in encode_base64 data_bytes = cls.encode(obj) File "/opt/conda/lib/python3.9/site-packages/tiledb/cloud/_results/codecs.py", line 151, in encode return cloudpickle.dumps(obj, protocol=_PICKLE_PROTOCOL) File "/opt/conda/lib/python3.9/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps cp.dump(obj) File "/opt/conda/lib/python3.9/site-packages/cloudpickle/cloudpickle_fast.py", line 632, in dump return Pickler.dump(self, obj) TypeError: cannot pickle 'FilterList' object ```
1 parent 256eb5e commit 589958e

File tree

4 files changed

+178
-13
lines changed

4 files changed

+178
-13
lines changed

apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def install_extra_driver_modules():
154154
import numpy as np
155155

156156
import tiledb
157+
from tiledb.vector_search.object_api import ObjectIndex
158+
from tiledb.vector_search.storage_formats import storage_formats
157159

158160
def instantiate_object(code, class_name, **kwargs):
159161
import importlib.util
@@ -178,7 +180,7 @@ def instantiate_object(code, class_name, **kwargs):
178180
return class_(**kwargs)
179181

180182
logger = setup(config, verbose)
181-
obj_index = object_index.ObjectIndex(
183+
obj_index = ObjectIndex(
182184
object_index_uri,
183185
config=config,
184186
environment_variables=environment_variables,

apis/python/src/tiledb/vector_search/object_api/object_index.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,18 @@ def query(
223223
def update_object_reader(
224224
self,
225225
object_reader: ObjectReader,
226+
config: Optional[Mapping[str, Any]] = None,
226227
):
227-
self.object_reader = object_reader
228-
self.object_reader_source_code = get_source_code(object_reader)
229-
self.object_reader_class_name = object_reader.__class__.__name__
230-
self.object_reader_kwargs = json.dumps(object_reader.init_kwargs())
231-
group = tiledb.Group(self.uri, "w")
232-
group.meta["object_reader_source_code"] = self.object_reader_source_code
233-
group.meta["object_reader_class_name"] = self.object_reader_class_name
234-
group.meta["object_reader_kwargs"] = self.object_reader_kwargs
235-
group.close()
228+
with tiledb.scope_ctx(ctx_or_config=config):
229+
self.object_reader = object_reader
230+
self.object_reader_source_code = get_source_code(object_reader)
231+
self.object_reader_class_name = object_reader.__class__.__name__
232+
self.object_reader_kwargs = json.dumps(object_reader.init_kwargs())
233+
group = tiledb.Group(self.uri, "w")
234+
group.meta["object_reader_source_code"] = self.object_reader_source_code
235+
group.meta["object_reader_class_name"] = self.object_reader_class_name
236+
group.meta["object_reader_kwargs"] = self.object_reader_kwargs
237+
group.close()
236238

237239
def create_embeddings_partitioned_array(
238240
self,

apis/python/test/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
import tiledb
9+
from tiledb.cloud import groups
910
from tiledb.vector_search.storage_formats import STORAGE_VERSION
1011
from tiledb.vector_search.storage_formats import storage_formats
1112

@@ -363,3 +364,31 @@ def quantize_embeddings_int8(
363364
starts = ranges[0, :]
364365
steps = (ranges[1, :] - ranges[0, :]) / 255
365366
return ((embeddings - starts) / steps - 128).astype(np.int8)
367+
368+
369+
def setUpCloudToken():
370+
token = os.getenv("TILEDB_REST_TOKEN")
371+
if os.getenv("TILEDB_CLOUD_HELPER_VAR"):
372+
token = os.getenv("TILEDB_CLOUD_HELPER_VAR")
373+
tiledb.cloud.login(token=token)
374+
375+
376+
def create_cloud_uri(name):
377+
namespace, storage_path, _ = groups._default_ns_path_cred()
378+
storage_path = storage_path.replace("//", "/").replace("/", "//", 1)
379+
rand_name = random_name("vector_search")
380+
test_path = f"tiledb://{namespace}/{storage_path}/{rand_name}"
381+
return f"{test_path}/{name}"
382+
383+
384+
def delete_uri(uri, config):
385+
with tiledb.scope_ctx(ctx_or_config=config):
386+
try:
387+
group = tiledb.Group(uri, "m")
388+
except tiledb.TileDBError as err:
389+
message = str(err)
390+
if "does not exist" in message:
391+
return
392+
else:
393+
raise err
394+
group.delete(recursive=True)

apis/python/test/test_object_index.py

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import tiledb
6+
from tiledb.cloud.dag import Mode
67
from tiledb.vector_search.embeddings import ObjectEmbedding
78
from tiledb.vector_search.object_api import object_index
89
from tiledb.vector_search.object_readers import ObjectPartition
@@ -142,9 +143,9 @@ def read_objects_by_external_ids(self, ids: List[int]) -> OrderedDict:
142143
return {"object": objects, "external_id": external_ids}
143144

144145

145-
def evaluate_query(index_uri, query_kwargs, dim_id, vector_dim_offset):
146+
def evaluate_query(index_uri, query_kwargs, dim_id, vector_dim_offset, config=None):
146147
v_id = dim_id - vector_dim_offset
147-
index = object_index.ObjectIndex(uri=index_uri)
148+
index = object_index.ObjectIndex(uri=index_uri, config=config)
148149
distances, objects, metadata = index.query(
149150
{"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs
150151
)
@@ -188,7 +189,9 @@ def df_filter(row):
188189
object_ids, np.array([v_id, v_id + 1, v_id + 2, v_id + 3, v_id + 4])
189190
)
190191

191-
index = object_index.ObjectIndex(uri=index_uri, load_metadata_in_memory=False)
192+
index = object_index.ObjectIndex(
193+
uri=index_uri, load_metadata_in_memory=False, config=config
194+
)
192195
distances, objects, metadata = index.query(
193196
{"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=5, **query_kwargs
194197
)
@@ -296,6 +299,135 @@ def test_object_index_ivf_flat(tmp_path):
296299
)
297300

298301

302+
def test_object_index_ivf_flat_cloud(tmp_path):
303+
from common import create_cloud_uri
304+
from common import delete_uri
305+
from common import setUpCloudToken
306+
307+
setUpCloudToken()
308+
config = tiledb.cloud.Config().dict()
309+
index_uri = create_cloud_uri("object_index_ivf_flat")
310+
worker_resources = {"cpu": "1", "memory": "2Gi"}
311+
reader = TestReader(
312+
object_id_start=0,
313+
object_id_end=1000,
314+
vector_dim_offset=0,
315+
)
316+
embedding = TestEmbedding()
317+
318+
index = object_index.create(
319+
uri=index_uri,
320+
index_type="IVF_FLAT",
321+
object_reader=reader,
322+
embedding=embedding,
323+
config=config,
324+
)
325+
326+
# Check initial ingestion
327+
index.update_index(
328+
embeddings_generation_driver_mode=Mode.BATCH,
329+
embeddings_generation_mode=Mode.BATCH,
330+
vector_indexing_mode=Mode.BATCH,
331+
workers=2,
332+
worker_resources=worker_resources,
333+
driver_resources=worker_resources,
334+
kmeans_resources=worker_resources,
335+
ingest_resources=worker_resources,
336+
consolidate_partition_resources=worker_resources,
337+
objects_per_partition=500,
338+
partitions=10,
339+
config=config,
340+
)
341+
evaluate_query(
342+
index_uri=index_uri,
343+
query_kwargs={"nprobe": 10},
344+
dim_id=42,
345+
vector_dim_offset=0,
346+
config=config,
347+
)
348+
# Check that updating the same data doesn't create duplicates
349+
index.update_index(
350+
embeddings_generation_driver_mode=Mode.BATCH,
351+
embeddings_generation_mode=Mode.BATCH,
352+
vector_indexing_mode=Mode.BATCH,
353+
workers=2,
354+
worker_resources=worker_resources,
355+
driver_resources=worker_resources,
356+
kmeans_resources=worker_resources,
357+
ingest_resources=worker_resources,
358+
consolidate_partition_resources=worker_resources,
359+
objects_per_partition=500,
360+
partitions=10,
361+
config=config,
362+
)
363+
evaluate_query(
364+
index_uri=index_uri,
365+
query_kwargs={"nprobe": 10},
366+
dim_id=42,
367+
vector_dim_offset=0,
368+
config=config,
369+
)
370+
371+
# Add new data with a new reader
372+
reader = TestReader(
373+
object_id_start=1000,
374+
object_id_end=2000,
375+
vector_dim_offset=0,
376+
)
377+
index.update_object_reader(reader, config=config)
378+
index.update_index(
379+
embeddings_generation_driver_mode=Mode.BATCH,
380+
embeddings_generation_mode=Mode.BATCH,
381+
vector_indexing_mode=Mode.BATCH,
382+
workers=2,
383+
worker_resources=worker_resources,
384+
driver_resources=worker_resources,
385+
kmeans_resources=worker_resources,
386+
ingest_resources=worker_resources,
387+
consolidate_partition_resources=worker_resources,
388+
objects_per_partition=500,
389+
partitions=10,
390+
config=config,
391+
)
392+
evaluate_query(
393+
index_uri=index_uri,
394+
query_kwargs={"nprobe": 10},
395+
dim_id=1042,
396+
vector_dim_offset=0,
397+
config=config,
398+
)
399+
400+
# Check overwriting existing data
401+
reader = TestReader(
402+
object_id_start=1000,
403+
object_id_end=2000,
404+
vector_dim_offset=1000,
405+
)
406+
index.update_object_reader(reader, config=config)
407+
index.update_index(
408+
embeddings_generation_driver_mode=Mode.BATCH,
409+
embeddings_generation_mode=Mode.BATCH,
410+
vector_indexing_mode=Mode.BATCH,
411+
workers=2,
412+
worker_resources=worker_resources,
413+
driver_resources=worker_resources,
414+
kmeans_resources=worker_resources,
415+
ingest_resources=worker_resources,
416+
consolidate_partition_resources=worker_resources,
417+
objects_per_partition=500,
418+
partitions=10,
419+
config=config,
420+
)
421+
evaluate_query(
422+
index_uri=index_uri,
423+
query_kwargs={"nprobe": 10},
424+
dim_id=2042,
425+
vector_dim_offset=1000,
426+
config=config,
427+
)
428+
delete_uri(index_uri, config)
429+
430+
299431
def test_object_index_flat(tmp_path):
300432
reader = TestReader(
301433
object_id_start=0,

0 commit comments

Comments
 (0)