Commit 7bbcaa0

teo-tsirpanis authored and ihnorton committed
Add option to use C++ k-means implementation
Squashed from #147

pick 24f5bb6 Default the shuffled ID and index types of `kmeans_index` to `size_t`.
pick 8fe7850 Enable the k-means initialization tests.
pick 23bf690 Support specifying the seed when creating a `kmeans_index`.
pick ce8745f Avoid randomly choosing the same centroid many times.
pick 0674917 Apply some fixes to the superbuild CMake file from Core.
pick 960e036 Add default values for tolerance and number of threads in `kmeans_index`.
pick c37f20c Start writing the Python kmeans APIs in a separate file.
pick 87983b4 Set internal linkage to some utility functions.
pick 8e33ac4 Fix more duplicate symbol errors.
pick 43ef100 Add a kmeans predict function.
pick 115b8f2 Train the kmeans index in the Python wrapper.
pick c773e2a Use kmeans_fit in the ingestion code instead of sklearn.
pick 455ca20 Fix compile errors and a warning.
pick fcf88f3 More refactorings and use `array_to_matrix`.
pick f35f100 Fix errors in the ingestion.
pick 239a753 Improve a test and diagnostic output.
pick 66de269 Always use floats to train kmeans.
pick fc5c0cf Add more parameters to `kmeans_fit`.
pick 2879be9 Add a test that compares the results of sklearn's and our own kmeans implementation.
pick 94643ce Use kmeans_predict instead of sklearn. This removes the sklearn dependency for good.
pick 45f2852 Use common options across sklearn's and our kmeans implementations.
pick b307de5 Rename `kmeans++` to `k-means++` to match sklearn.
pick 584d548 Assert that the score of the our kmeans implementation is smaller than 1.5 times the score of sklearn's.
pick a7da424 fix transposed args in kmeans.cc -- add unit test [skip ci]
pick 8527303 Test both kmeans++ and random initialization.
pick 6575791 Fix formatting and delete commented code.
pick 34ddcb5 Make the kmeans test more deterministic.
pick 8769d04 Add back the asserts.
pick ef38b0b Add an opt-in switch to use sklearn's kmeans implementation.
1 parent e9c6772 commit 7bbcaa0
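
For orientation before the diff: the squashed commits above add a native `kmeans_fit`/`kmeans_predict` pair (backed by the new kmeans.cc source), switch ingestion away from sklearn, and add a test asserting the native score stays below 1.5 times sklearn's. The sketch below only illustrates that comparison criterion with sklearn and numpy; the native calls are shown commented out because their exact signatures are not visible in this diff and are assumptions.

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy data: 1000 vectors of dimension 8, in float32 as the k-means commits require.
rng = np.random.default_rng(0)
data = rng.standard_normal((1000, 8)).astype(np.float32)
k = 16

# Reference run with the options shared across both implementations
# ("k-means++" init, fixed seed).
skl = KMeans(n_clusters=k, init="k-means++", n_init=1, random_state=0).fit(data)

def inertia(points, centroids, labels):
    # Sum of squared distances from each point to its assigned centroid,
    # i.e. the "score" the comparison test checks.
    return float(((points - centroids[labels]) ** 2).sum())

print("sklearn inertia:", skl.inertia_)
print("recomputed     :", inertia(data, skl.cluster_centers_, skl.labels_))

# Hypothetical native calls (names from the commit messages, signatures assumed):
# from tiledb.vector_search.module import kmeans_fit, kmeans_predict
# centroids = kmeans_fit(...)              # train on `data`
# labels = kmeans_predict(centroids, data)
# assert inertia(data, centroids, labels) < 1.5 * skl.inertia_
```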


61 files changed: +1670 -2979 lines changed

.github/workflows/build_wheels.yml

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - release-*
-      - '*wheel*' # must quote since "*" is a YAML reserved character; we want a string
     tags:
       - '*'
   pull_request:

.github/workflows/ci_python.yml

Lines changed: 0 additions & 4 deletions
@@ -35,8 +35,4 @@ jobs:
         #pip uninstall -y tiledb.vector_search
         #pip install -e .
         #pytest
-        pip install -r test/ipynb/requirements.txt
-        pytest --nbmake test/ipynb
-      env:
-        TILEDB_REST_TOKEN: ${{ secrets.TILEDB_CLOUD_HELPER_VAR }}
       shell: bash -el {0}

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ RUN conda config --prepend channels conda-forge
 # Install mamba for faster installations
 RUN conda install mamba
 
-RUN mamba install -y -c tiledb 'tiledb>=2.17,<2.18' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"
+RUN mamba install -y -c tiledb 'tiledb>=2.16,<2.17' tiledb-py cmake pybind11 pytest c-compiler cxx-compiler ninja openblas-devel "pip>22"
 
 COPY . TileDB-Vector-Search/
 
README.md

Lines changed: 0 additions & 12 deletions
@@ -42,15 +42,3 @@ development-build instructions. For large new
 features, please open an issue to discuss goals and approach in order
 to ensure a smooth PR integration and review process. All contributions
 must be licensed under the repository's [MIT License](../LICENSE).
-
-# Testing
-
-* Unit tests: `pytest`
-* Demo notebooks:
-  * ```
-    pip install -r test/ipynb/requirements.txt
-    pytest --nbmake test/ipynb
-    ```
-* Credentials:
-  * Some tests run on TileDB Cloud using your current environment variable `TILEDB_REST_TOKEN` -- you will need a valid API token for the tests to pass
-  * For continuous integration, the token is configured for the `unittest` user and all tests should pass

apis/python/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,10 @@ find_package(pybind11 CONFIG REQUIRED)
 
 set(VSPY_TARGET_NAME _tiledbvspy)
 
-python_add_library(${VSPY_TARGET_NAME} MODULE "src/tiledb/vector_search/module.cc" WITH_SOABI)
+python_add_library(${VSPY_TARGET_NAME} MODULE
+  "src/tiledb/vector_search/module.cc"
+  "src/tiledb/vector_search/kmeans.cc"
+  WITH_SOABI)
 
 target_link_libraries(${VSPY_TARGET_NAME}
   PRIVATE

apis/python/pyproject.toml

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,7 @@
 [project]
 name = "tiledb-vector-search"
-dynamic = ["version"]
+version = "0.0.10"
+#dynamic = ["version"]
 description = "TileDB Vector Search Python client"
 license = { text = "MIT" }
 readme = "README.md"
@@ -18,14 +19,14 @@ classifiers = [
 ]
 
 dependencies = [
-    "tiledb-cloud>=0.11",
-    "tiledb>=0.23.1",
+    "tiledb-cloud>=0.10.5",
+    "tiledb>=0.15.2",
     "typing-extensions", # for tiledb-cloud indirect, x-ref https://github.com/TileDB-Inc/TileDB-Cloud-Py/pull/428
     "scikit-learn",
 ]
 
 [project.optional-dependencies]
-test = ["nbmake", "pytest"]
+test = ["pytest"]
 
 
 [project.urls]
@@ -46,7 +47,7 @@ zip-safe = false
 
 [tool.setuptools_scm]
 root = "../.."
-write_to = "apis/python/src/tiledb/vector_search/version.py"
+#write_to = "apis/python/src/tiledb/vector_search/version.py"
 
 [tool.ruff]
 extend-select = ["I"]

apis/python/requirements-py.txt

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 numpy==1.24.3
-tiledb-cloud==0.10.24
-tiledb==0.23.1
+tiledb-cloud==0.10.5
+tiledb==0.21.3

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 18 additions & 14 deletions
@@ -1,21 +1,25 @@
-# Re-import mode from cloud.dag
-from tiledb.cloud.dag.mode import Mode
-
 from . import utils
-from .flat_index import FlatIndex
 from .index import Index
-from .ingestion import ingest
 from .ivf_flat_index import IVFFlatIndex
-from .module import (array_to_matrix, ivf_index, ivf_index_tdb, ivf_query,
-                     ivf_query_ram, load_as_array, load_as_matrix,
-                     partition_ivf_index, query_vq_heap, query_vq_nth,
-                     validate_top_k)
-from .storage_formats import STORAGE_VERSION, storage_formats
+from .flat_index import FlatIndex
+from .ingestion import ingest
+from .storage_formats import storage_formats, STORAGE_VERSION
+from .module import load_as_array
+from .module import load_as_matrix
+from .module import (
+    query_vq_heap,
+    query_vq_nth,
+    ivf_query,
+    ivf_query_ram,
+    validate_top_k,
+    array_to_matrix,
+    ivf_index,
+    ivf_index_tdb,
+    partition_ivf_index,
+)
 
-try:
-    from tiledb.vector_search.version import version as __version__
-except ImportError:
-    __version__ = "0.0.0.local"
+# Re-import mode from cloud.dag
+from tiledb.cloud.dag.mode import Mode
 
 __all__ = [
     "FlatIndex",
apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 25 additions & 132 deletions
@@ -1,19 +1,12 @@
-import json
-from typing import Any, Mapping
-
 import numpy as np
 
-from tiledb.vector_search import index
 from tiledb.vector_search.module import *
-from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
-                                                   storage_formats)
-
-MAX_INT32 = np.iinfo(np.dtype("int32")).max
-TILE_SIZE_BYTES = 128000000  # 128MB
-INDEX_TYPE = "FLAT"
+from tiledb.vector_search.storage_formats import storage_formats
+from tiledb.vector_search.index import Index
+from typing import Any, Mapping
 
 
-class FlatIndex(index.Index):
+class FlatIndex(Index):
     """
     Open a flat index
 
@@ -29,51 +22,33 @@ def __init__(
         self,
         uri: str,
         config: Optional[Mapping[str, Any]] = None,
-        timestamp=None,
-        **kwargs,
     ):
-        super().__init__(uri=uri, config=config, timestamp=timestamp)
-        self.index_type = INDEX_TYPE
+        super().__init__(uri=uri, config=config)
+        self.index_type = "FLAT"
         self._index = None
-        self.db_uri = self.group[
-            storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
-            + self.index_version
+        self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri
+        schema = tiledb.ArraySchema.load(
+            self.db_uri, ctx=tiledb.Ctx(self.config)
+        )
+        self.size = schema.domain.dim(1).domain[1]+1
+        self._db = load_as_matrix(
+            self.db_uri,
+            ctx=self.ctx,
+            config=config,
+        )
+        self.ids_uri = self.group[
+            storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
         ].uri
-        schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
-        if self.base_size == -1:
-            self.size = schema.domain.dim(1).domain[1] + 1
+        if tiledb.array_exists(self.ids_uri, self.ctx):
+            self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
         else:
-            self.size = self.base_size
+            self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
 
-        self.dtype = np.dtype(self.group.meta.get("dtype", None))
-        if (
-            storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
-            in self.group
-        ):
-            self.ids_uri = self.group[
-                storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
-                + self.index_version
-            ].uri
+        dtype = self.group.meta.get("dtype", None)
+        if dtype is None:
+            self.dtype = self._db.dtype
         else:
-            self.ids_uri = ""
-        if self.size > 0:
-            self._db = load_as_matrix(
-                self.db_uri,
-                ctx=self.ctx,
-                config=config,
-                size=self.size,
-                timestamp=self.base_array_timestamp,
-            )
-            if self.dtype is None:
-                self.dtype = self._db.dtype
-            # Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
-            # that the external_ids were the position of the vector in the array.
-            if self.ids_uri == "":
-                self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
-            else:
-                self._ids = read_vector_u64(
-                    self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
-                )
+            self.dtype = np.dtype(dtype)
 
     def query_internal(
         self,
@@ -96,92 +71,10 @@ def query_internal(
         # TODO:
         # - typecheck queries
         # - add all the options and query strategies
-        if self.size == 0:
-            return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
-                (queries.shape[0], k), index.MAX_UINT64
-            )
 
         assert queries.dtype == np.float32
 
         queries_m = array_to_matrix(np.transpose(queries))
         d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)
 
         return np.transpose(np.array(d)), np.transpose(np.array(i))
-
-
-def create(
-    uri: str,
-    dimensions: int,
-    vector_type: np.dtype,
-    group_exists: bool = False,
-    config: Optional[Mapping[str, Any]] = None,
-    **kwargs,
-) -> FlatIndex:
-    index.create_metadata(
-        uri=uri,
-        dimensions=dimensions,
-        vector_type=vector_type,
-        index_type=INDEX_TYPE,
-        group_exists=group_exists,
-        config=config,
-    )
-    with tiledb.scope_ctx(ctx_or_config=config):
-        group = tiledb.Group(uri, "w")
-        tile_size = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
-        ids_array_name = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
-        parts_array_name = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
-        ids_uri = f"{uri}/{ids_array_name}"
-        parts_uri = f"{uri}/{parts_array_name}"
-
-        ids_array_rows_dim = tiledb.Dim(
-            name="rows",
-            domain=(0, MAX_INT32),
-            tile=tile_size,
-            dtype=np.dtype(np.int32),
-        )
-        ids_array_dom = tiledb.Domain(ids_array_rows_dim)
-        ids_attr = tiledb.Attr(
-            name="values",
-            dtype=np.dtype(np.uint64),
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
-        )
-        ids_schema = tiledb.ArraySchema(
-            domain=ids_array_dom,
-            sparse=False,
-            attrs=[ids_attr],
-            cell_order="col-major",
-            tile_order="col-major",
-        )
-        tiledb.Array.create(ids_uri, ids_schema)
-        group.add(ids_uri, name=ids_array_name)
-
-        parts_array_rows_dim = tiledb.Dim(
-            name="rows",
-            domain=(0, dimensions - 1),
-            tile=dimensions,
-            dtype=np.dtype(np.int32),
-        )
-        parts_array_cols_dim = tiledb.Dim(
-            name="cols",
-            domain=(0, MAX_INT32),
-            tile=tile_size,
-            dtype=np.dtype(np.int32),
-        )
-        parts_array_dom = tiledb.Domain(parts_array_rows_dim, parts_array_cols_dim)
-        parts_attr = tiledb.Attr(
-            name="values",
-            dtype=vector_type,
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
-        )
-        parts_schema = tiledb.ArraySchema(
-            domain=parts_array_dom,
-            sparse=False,
-            attrs=[parts_attr],
-            cell_order="col-major",
-            tile_order="col-major",
-        )
-        tiledb.Array.create(parts_uri, parts_schema)
-        group.add(parts_uri, name=parts_array_name)
-
-        group.close()
-    return FlatIndex(uri=uri, config=config)
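
One detail worth calling out from the trimmed `query_internal` above: queries must be float32 and are transposed into column form before being handed to `query_vq_heap`, and the results are transposed back. A small numpy-only sketch of that shape/dtype contract, using made-up sizes (7 queries of dimension 128):

```python
import numpy as np

# Caller-side view of what query_internal enforces: queries are float32,
# one query per row, with the dimensionality along the columns.
queries = np.random.rand(7, 128).astype(np.float32)   # 7 queries, dimension 128
assert queries.dtype == np.float32                     # same assert as query_internal

# query_internal transposes before building the matrix passed to query_vq_heap,
# so each query becomes a column:
queries_t = np.transpose(queries)                      # shape (128, 7)
print("matrix handed to the C++ side:", queries_t.shape)

# Distances and ids come back per-column as well and are transposed back to
# (n_queries, k) before being returned to the caller.
```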
