
Commit c394d0b

lums658 authored and ihnorton committed
Updates for C++ k-means integration
Squashed from #147:

pick 697c481 Parameterize min heap with comparison function [skip ci]
pick e5a797a Debug zero cluster fix [skip ci]
pick d085f66 Uncomment debug statements [skip ci]
pick 6b07f17 Initial partition-equalization
pick 9867c90 Updates for kmeans and kmeans++
pick ff71e87 Small update
pick 9176a54 clean up warnings, clang-format
pick e5e3690 Add documentation, update unit tests
pick 76b2fe8 Replace std::abs<size_t> with std::labs
pick 28651e6 Supress std::labs warnings
pick f156dfa Small bug fix in predict
pick d8b0871 clang-format [skip ci]
pick ae5fca4 Add documentation, verify build
1 parent 7bbcaa0 commit c394d0b


45 files changed: 1,578 additions and 777 deletions (large commit; only a subset of the file diffs is shown below).

.github/workflows/build_wheels.yml (1 addition, 0 deletions)

```diff
@@ -4,6 +4,7 @@ on:
   push:
     branches:
       - release-*
+      - '*wheel*' # must quote since "*" is a YAML reserved character; we want a string
     tags:
       - '*'
   pull_request:
```
.github/workflows/ci_python.yml (4 additions, 0 deletions)

```diff
@@ -35,4 +35,8 @@ jobs:
         #pip uninstall -y tiledb.vector_search
         #pip install -e .
         #pytest
+        pip install -r test/ipynb/requirements.txt
+        pytest --nbmake test/ipynb
+      env:
+        TILEDB_REST_TOKEN: ${{ secrets.TILEDB_CLOUD_HELPER_VAR }}
       shell: bash -el {0}
```

Dockerfile (1 addition, 1 deletion)

```diff
@@ -8,7 +8,7 @@ RUN conda config --prepend channels conda-forge
 # Install mamba for faster installations
 RUN conda install mamba

-RUN mamba install -y -c tiledb 'tiledb>=2.16,<2.17' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"
+RUN mamba install -y -c tiledb 'tiledb>=2.17,<2.18' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"

 COPY . TileDB-Vector-Search/
```
README.md (12 additions, 0 deletions)

````diff
@@ -42,3 +42,15 @@ development-build instructions. For large new
 features, please open an issue to discuss goals and approach in order
 to ensure a smooth PR integration and review process. All contributions
 must be licensed under the repository's [MIT License](../LICENSE).
+
+# Testing
+
+* Unit tests: `pytest`
+* Demo notebooks:
+  * ```
+    pip install -r test/ipynb/requirements.txt
+    pytest --nbmake test/ipynb
+    ```
+* Credentials:
+  * Some tests run on TileDB Cloud using your current environment variable `TILEDB_REST_TOKEN` -- you will need a valid API token for the tests to pass
+  * For continuous integration, the token is configured for the `unittest` user and all tests should pass
````
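As editorial context for the credentials bullets above, a minimal sketch (assumed usage, not part of this commit) of how a cloud-backed test session might consume `TILEDB_REST_TOKEN`; `tiledb.cloud.login` comes from the separate tiledb-cloud client package:

```python
import os

import tiledb.cloud  # separate tiledb-cloud client package, assumed installed

# CI injects the API token through the TILEDB_REST_TOKEN environment variable.
token = os.environ.get("TILEDB_REST_TOKEN")
if not token:
    raise RuntimeError("TILEDB_REST_TOKEN is not set; cloud-backed tests cannot run")

# Authenticate this session against TileDB Cloud using the token.
tiledb.cloud.login(token=token)
```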

apis/python/pyproject.toml (2 additions, 2 deletions)

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "tiledb-vector-search"
-version = "0.0.10"
+version = "0.0.14"
 #dynamic = ["version"]
 description = "TileDB Vector Search Python client"
 license = { text = "MIT" }
@@ -26,7 +26,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-test = ["pytest"]
+test = ["nbmake", "pytest"]


 [project.urls]
```
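A usage note, editorial rather than from the diff: `nbmake` is the pytest plugin that provides the `--nbmake` flag, so `pip install "tiledb-vector-search[test]"` now pulls in the plugin behind the `pytest --nbmake` invocation used in CI (the notebooks' own dependencies still come from `test/ipynb/requirements.txt`).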

apis/python/src/tiledb/vector_search/flat_index.py (7 additions, 4 deletions)

```diff
@@ -36,10 +36,13 @@ def __init__(
             ctx=self.ctx,
             config=config,
         )
-        self.ids_uri = self.group[
-            storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
-        ].uri
-        if tiledb.array_exists(self.ids_uri, self.ctx):
+
+        # Check for existence of the ids array. Previous versions did not use external_ids
+        # during ingestion, assuming that an external_id was the vector's position in the array.
+        if storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version in self.group:
+            self.ids_uri = self.group[
+                storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
+            ].uri
             self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
         else:
             self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
```
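An editorial aside on the membership test above: `tiledb.Group` supports `in` for member names, which is what makes the backward-compatibility fallback possible. A minimal sketch of the pattern, with the group URI, member name, and size all hypothetical:

```python
import numpy as np
import tiledb

group = tiledb.Group("s3://bucket/my_index")  # hypothetical index group URI
ids_member = "ids_0.3"                        # hypothetical ids array member name

if ids_member in group:
    # Newer indexes store explicit external ids in a dedicated member array.
    ids_uri = group[ids_member].uri
else:
    # Older indexes: an external id is simply the vector's position, 0..size-1.
    size = 1_000  # hypothetical vector count
    ids = np.arange(size).astype(np.uint64)
```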
apis/python/src/tiledb/vector_search/index.py (99 additions, 49 deletions)

```diff
@@ -1,3 +1,5 @@
+import concurrent.futures as futures
+import os
 import numpy as np
 import sys

@@ -20,6 +22,7 @@ class Index:
     config: Optional[Mapping[str, Any]]
         config dictionary, defaults to None
     """
+
     def __init__(
         self,
         uri: str,
@@ -36,16 +39,28 @@ def __init__(
         self.storage_version = self.group.meta.get("storage_version", "0.1")
         self.update_arrays_uri = None
         self.index_version = self.group.meta.get("index_version", "")
-
+        self.thread_executor = futures.ThreadPoolExecutor()

     def query(self, queries: np.ndarray, k, **kwargs):
-        updated_ids = set(self.read_updated_ids())
-        retrieval_k = k
-        if len(updated_ids) > 0:
-            retrieval_k = 2*k
-        internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
         if self.update_arrays_uri is None:
-            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
+            return self.query_internal(queries, k, **kwargs)
+
+        # Query with updates
+        # Perform the queries in parallel
+        retrieval_k = 2 * k
+        kwargs["nthreads"] = int(os.cpu_count() / 2)
+        future = self.thread_executor.submit(
+            Index.query_additions,
+            queries,
+            k,
+            self.dtype,
+            self.update_arrays_uri,
+            int(os.cpu_count() / 2),
+        )
+        internal_results_d, internal_results_i = self.query_internal(
+            queries, retrieval_k, **kwargs
+        )
+        addition_results_d, addition_results_i, updated_ids = future.result()

         # Filter updated vectors
         query_id = 0
```
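To make the concurrency pattern in the hunk above concrete: the scan of pending updates is submitted to a thread pool while the main index query (over-fetching 2*k candidates) runs on the calling thread, and the two result sets are joined afterwards. A minimal self-contained sketch, with `query_updates` and `query_index` as hypothetical stand-ins for `query_additions` and `query_internal`:

```python
import concurrent.futures as futures
import os

import numpy as np

executor = futures.ThreadPoolExecutor()

def query_updates(queries, k):
    # Stand-in for the brute-force scan of the updates array.
    return np.random.rand(len(queries), k).astype(np.float32)

def query_index(queries, k, nthreads):
    # Stand-in for the main index query.
    return np.random.rand(len(queries), k).astype(np.float32)

def query(queries, k):
    nthreads = int(os.cpu_count() / 2)                   # split cores between the two scans
    future = executor.submit(query_updates, queries, k)  # updates scan on a worker thread
    index_d = query_index(queries, 2 * k, nthreads)      # index query on the calling thread
    updates_d = future.result()                          # join before merging
    return index_d, updates_d

index_d, updates_d = query(np.zeros((4, 8), dtype=np.float32), k=5)
```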
```diff
@@ -55,119 +70,137 @@ def query(self, queries: np.ndarray, k, **kwargs):
                 if res in updated_ids:
                     internal_results_d[query_id, res_id] = MAX_FLOAT_32
                     internal_results_i[query_id, res_id] = MAX_UINT64
+                if (
+                    internal_results_d[query_id, res_id] == 0
+                    and internal_results_i[query_id, res_id] == 0
+                ):
+                    internal_results_d[query_id, res_id] = MAX_FLOAT_32
+                    internal_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
         sort_index = np.argsort(internal_results_d, axis=1)
         internal_results_d = np.take_along_axis(internal_results_d, sort_index, axis=1)
         internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)

         # Merge update results
-        addition_results_d, addition_results_i = self.query_additions(queries, k)
         if addition_results_d is None:
             return internal_results_d[:, 0:k], internal_results_i[:, 0:k]

         query_id = 0
         for query in addition_results_d:
             res_id = 0
             for res in query:
-                if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
+                if (
+                    addition_results_d[query_id, res_id] == 0
+                    and addition_results_i[query_id, res_id] == 0
+                ):
                     addition_results_d[query_id, res_id] = MAX_FLOAT_32
                     addition_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1

-
         results_d = np.hstack((internal_results_d, addition_results_d))
         results_i = np.hstack((internal_results_i, addition_results_i))
         sort_index = np.argsort(results_d, axis=1)
         results_d = np.take_along_axis(results_d, sort_index, axis=1)
         results_i = np.take_along_axis(results_i, sort_index, axis=1)
         return results_d[:, 0:k], results_i[:, 0:k]

-    def query_internal(self, queries: np.ndarray, k, **kwargs):
-        raise NotImplementedError
-
-    def query_additions(self, queries: np.ndarray, k):
+    @staticmethod
+    def query_additions(
+        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
+    ):
         assert queries.dtype == np.float32
-        additions_vectors, additions_external_ids = self.read_additions()
+        additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
+            update_arrays_uri
+        )
         if additions_vectors is None:
-            return None, None
+            return None, None, updated_ids
+
         queries_m = array_to_matrix(np.transpose(queries))
         d, i = query_vq_heap_pyarray(
-            array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
+            array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
             queries_m,
             StdVector_u64(additions_external_ids),
             k,
-            8)
-        return np.transpose(np.array(d)), np.transpose(np.array(i))
+            nthreads,
+        )
+        return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
+
+    @staticmethod
+    def read_additions(update_arrays_uri) -> (np.ndarray, np.array):
+        if update_arrays_uri is None:
+            return None, None, np.array([], np.uint64)
+        updates_array = tiledb.open(update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=("vector",), coords=True)
+        data = q[:]
+        updates_array.close()
+        updated_ids = data["external_id"]
+        additions_filter = [len(item) > 0 for item in data["vector"]]
+        if len(data["external_id"][additions_filter]) > 0:
+            return (
+                np.vstack(data["vector"][additions_filter]),
+                data["external_id"][additions_filter],
+                updated_ids
+            )
+        else:
+            return None, None, updated_ids
+
+    def query_internal(self, queries: np.ndarray, k, **kwargs):
+        raise NotImplementedError

     def update(self, vector: np.array, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        vectors = np.empty((1), dtype='O')
+        vectors = np.empty((1), dtype="O")
         vectors[0] = vector
-        updates_array[external_id] = {'vector': vectors}
+        updates_array[external_id] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()

     def update_batch(self, vectors: np.ndarray, external_ids: np.array):
         updates_array = self.open_updates_array()
-        updates_array[external_ids] = {'vector': vectors}
+        updates_array[external_ids] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()

     def delete(self, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        deletes = np.empty((1), dtype='O')
+        deletes = np.empty((1), dtype="O")
         deletes[0] = np.array([], dtype=self.dtype)
-        updates_array[external_id] = {'vector': deletes}
+        updates_array[external_id] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()

     def delete_batch(self, external_ids: np.array):
         updates_array = self.open_updates_array()
-        deletes = np.empty((len(external_ids)), dtype='O')
+        deletes = np.empty((len(external_ids)), dtype="O")
         for i in range(len(external_ids)):
             deletes[i] = np.array([], dtype=self.dtype)
-        updates_array[external_ids] = {'vector': deletes}
+        updates_array[external_ids] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()

     def consolidate_update_fragments(self):
         fragments_info = tiledb.array_fragments(self.update_arrays_uri)
-        if(len(fragments_info) > 10):
+        if len(fragments_info) > 10:
             tiledb.consolidate(self.update_arrays_uri)
             tiledb.vacuum(self.update_arrays_uri)

     def get_updates_uri(self):
         return self.update_arrays_uri

-    def read_additions(self) -> (np.ndarray, np.array):
-        if self.update_arrays_uri is None:
-            return None, None
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        additions_filter = [len(item) > 0 for item in data["vector"]]
-        if len(data["external_id"][additions_filter]) > 0:
-            return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
-        else:
-            return None, None
-    def read_updated_ids(self) -> np.array:
-        if self.update_arrays_uri is None:
-            return np.array([], np.uint64)
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        return data["external_id"]
-
     def open_updates_array(self):
         if self.update_arrays_uri is None:
-            updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
+            updates_array_name = storage_formats[self.storage_version][
+                "UPDATES_ARRAY_NAME"
+            ]
             updates_array_uri = f"{self.group.uri}/{updates_array_name}"
             if tiledb.array_exists(updates_array_uri):
                 raise RuntimeError(f"Array {updates_array_uri} already exists.")
             external_id_dim = tiledb.Dim(
-                name="external_id", domain=(0, MAX_UINT64-1), dtype=np.dtype(np.uint64)
+                name="external_id",
+                domain=(0, MAX_UINT64 - 1),
+                dtype=np.dtype(np.uint64),
             )
             dom = tiledb.Domain(external_id_dim)
             vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
```
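The merge step in this hunk (`np.hstack`, `np.argsort`, `np.take_along_axis`, then a top-k slice) can be exercised in isolation. A small sketch with made-up distances and ids:

```python
import numpy as np

k = 2
# Hypothetical per-query results: distances and ids from the index and the updates scan.
internal_d = np.array([[0.10, 0.40], [0.20, 0.90]], dtype=np.float32)
internal_i = np.array([[11, 12], [21, 22]], dtype=np.uint64)
addition_d = np.array([[0.05, 0.70], [0.95, 0.99]], dtype=np.float32)
addition_i = np.array([[31, 32], [41, 42]], dtype=np.uint64)

# Concatenate the candidate lists, then re-sort each row by distance.
results_d = np.hstack((internal_d, addition_d))
results_i = np.hstack((internal_i, addition_i))
sort_index = np.argsort(results_d, axis=1)
results_d = np.take_along_axis(results_d, sort_index, axis=1)
results_i = np.take_along_axis(results_i, sort_index, axis=1)

# Keep only the k best merged results per query.
print(results_d[:, 0:k])  # row 0: 0.05, 0.10; row 1: 0.20, 0.90
print(results_i[:, 0:k])  # row 0: ids 31, 11; row 1: ids 21, 22
```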
```diff
@@ -188,13 +221,30 @@ def open_updates_array(self):

     def consolidate_updates(self):
         from tiledb.vector_search.ingestion import ingest
+
         new_index = ingest(
             index_type=self.index_type,
             index_uri=self.uri,
             size=self.size,
             source_uri=self.db_uri,
             external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri
+            updates_uri=self.update_arrays_uri,
         )
         tiledb.Array.delete_array(self.update_arrays_uri)
+        self.group.close()
+        self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
+        self.group.remove(self.update_arrays_uri)
+        self.group.close()
         return new_index
+
+    @staticmethod
+    def delete_index(uri, config):
+        try:
+            group = tiledb.Group(uri, "m", config=config)
+        except tiledb.TileDBError as err:
+            message = str(err)
+            if "group does not exist" in message:
+                return
+            else:
+                raise err
+        group.delete()
```
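For reference, a standalone sketch of the kind of sparse updates array that `open_updates_array` builds: a single uint64 `external_id` dimension and a var-length `vector` attribute, with an empty cell encoding a deletion. The URI and vector dtype here are hypothetical:

```python
import numpy as np
import tiledb

uri = "/tmp/example_updates"  # hypothetical location
MAX_UINT64 = np.iinfo(np.uint64).max

# Sparse schema: uint64 external_id dimension, var-length vector attribute.
dim = tiledb.Dim(name="external_id", domain=(0, MAX_UINT64 - 1), dtype=np.uint64)
attr = tiledb.Attr(name="vector", dtype=np.float32, var=True)
schema = tiledb.ArraySchema(domain=tiledb.Domain(dim), sparse=True, attrs=[attr])
tiledb.Array.create(uri, schema)

# Upsert one vector and delete another (an empty cell marks a deletion).
with tiledb.open(uri, mode="w") as a:
    cells = np.empty((2,), dtype="O")
    cells[0] = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    cells[1] = np.array([], dtype=np.float32)  # deletion marker
    a[np.array([7, 9], dtype=np.uint64)] = {"vector": cells}
```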
