Skip to content

Commit bcfdaa1

Browse files
committed
Sync C++ kmeans branch with main and fix errors
1 parent c394d0b commit bcfdaa1

28 files changed

+1886
-781
lines changed

apis/python/pyproject.toml

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
[project]
22
name = "tiledb-vector-search"
3-
version = "0.0.14"
4-
#dynamic = ["version"]
3+
dynamic = ["version"]
54
description = "TileDB Vector Search Python client"
65
license = { text = "MIT" }
76
readme = "README.md"
@@ -19,8 +18,8 @@ classifiers = [
1918
]
2019

2120
dependencies = [
22-
"tiledb-cloud>=0.10.5",
23-
"tiledb>=0.15.2",
21+
"tiledb-cloud>=0.11",
22+
"tiledb>=0.23.1",
2423
"typing-extensions", # for tiledb-cloud indirect, x-ref https://github.com/TileDB-Inc/TileDB-Cloud-Py/pull/428
2524
"scikit-learn",
2625
]
@@ -47,7 +46,7 @@ zip-safe = false
4746

4847
[tool.setuptools_scm]
4948
root = "../.."
50-
#write_to = "apis/python/src/tiledb/vector_search/version.py"
49+
write_to = "apis/python/src/tiledb/vector_search/version.py"
5150

5251
[tool.ruff]
5352
extend-select = ["I"]

apis/python/requirements-py.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
numpy==1.24.3
2-
tiledb-cloud==0.10.5
3-
tiledb==0.21.3
2+
tiledb-cloud==0.10.24
3+
tiledb==0.23.1

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,21 @@
1+
# Re-import mode from cloud.dag
2+
from tiledb.cloud.dag.mode import Mode
3+
14
from . import utils
2-
from .index import Index
3-
from .ivf_flat_index import IVFFlatIndex
45
from .flat_index import FlatIndex
6+
from .index import Index
57
from .ingestion import ingest
6-
from .storage_formats import storage_formats, STORAGE_VERSION
7-
from .module import load_as_array
8-
from .module import load_as_matrix
9-
from .module import (
10-
query_vq_heap,
11-
query_vq_nth,
12-
ivf_query,
13-
ivf_query_ram,
14-
validate_top_k,
15-
array_to_matrix,
16-
ivf_index,
17-
ivf_index_tdb,
18-
partition_ivf_index,
19-
)
8+
from .ivf_flat_index import IVFFlatIndex
9+
from .module import (array_to_matrix, ivf_index, ivf_index_tdb, ivf_query,
10+
ivf_query_ram, load_as_array, load_as_matrix,
11+
partition_ivf_index, query_vq_heap, query_vq_nth,
12+
validate_top_k)
13+
from .storage_formats import STORAGE_VERSION, storage_formats
2014

21-
# Re-import mode from cloud.dag
22-
from tiledb.cloud.dag.mode import Mode
15+
try:
16+
from tiledb.vector_search.version import version as __version__
17+
except ImportError:
18+
__version__ = "0.0.0.local"
2319

2420
__all__ = [
2521
"FlatIndex",
apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 132 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,19 @@
1+
import json
2+
from typing import Any, Mapping
3+
14
import numpy as np
25

6+
from tiledb.vector_search import index
37
from tiledb.vector_search.module import *
4-
from tiledb.vector_search.storage_formats import storage_formats
5-
from tiledb.vector_search.index import Index
6-
from typing import Any, Mapping
8+
from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
9+
storage_formats)
710

11+
MAX_INT32 = np.iinfo(np.dtype("int32")).max
12+
TILE_SIZE_BYTES = 128000000 # 128MB
13+
INDEX_TYPE = "FLAT"
814

9-
class FlatIndex(Index):
15+
16+
class FlatIndex(index.Index):
1017
"""
1118
Open a flat index
1219
@@ -22,36 +29,51 @@ def __init__(
2229
self,
2330
uri: str,
2431
config: Optional[Mapping[str, Any]] = None,
32+
timestamp=None,
33+
**kwargs,
2534
):
26-
super().__init__(uri=uri, config=config)
27-
self.index_type = "FLAT"
35+
super().__init__(uri=uri, config=config, timestamp=timestamp)
36+
self.index_type = INDEX_TYPE
2837
self._index = None
29-
self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri
30-
schema = tiledb.ArraySchema.load(
31-
self.db_uri, ctx=tiledb.Ctx(self.config)
32-
)
33-
self.size = schema.domain.dim(1).domain[1]+1
34-
self._db = load_as_matrix(
35-
self.db_uri,
36-
ctx=self.ctx,
37-
config=config,
38-
)
38+
self.db_uri = self.group[
39+
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
40+
+ self.index_version
41+
].uri
42+
schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
43+
if self.base_size == -1:
44+
self.size = schema.domain.dim(1).domain[1] + 1
45+
else:
46+
self.size = self.base_size
3947

40-
# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
41-
# that the external_ids were the position of the vector in the array.
42-
if storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version in self.group:
48+
self.dtype = np.dtype(self.group.meta.get("dtype", None))
49+
if (
50+
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
51+
in self.group
52+
):
4353
self.ids_uri = self.group[
44-
storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
54+
storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
55+
+ self.index_version
4556
].uri
46-
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
47-
else:
48-
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
49-
50-
dtype = self.group.meta.get("dtype", None)
51-
if dtype is None:
52-
self.dtype = self._db.dtype
5357
else:
54-
self.dtype = np.dtype(dtype)
58+
self.ids_uri = ""
59+
if self.size > 0:
60+
self._db = load_as_matrix(
61+
self.db_uri,
62+
ctx=self.ctx,
63+
config=config,
64+
size=self.size,
65+
timestamp=self.base_array_timestamp,
66+
)
67+
if self.dtype is None:
68+
self.dtype = self._db.dtype
69+
# Check for existence of ids array. Previous versions were not using external_ids in the ingestion assuming
70+
# that the external_ids were the position of the vector in the array.
71+
if self.ids_uri == "":
72+
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
73+
else:
74+
self._ids = read_vector_u64(
75+
self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
76+
)
5577

5678
def query_internal(
5779
self,
@@ -74,10 +96,92 @@ def query_internal(
7496
# TODO:
7597
# - typecheck queries
7698
# - add all the options and query strategies
99+
if self.size == 0:
100+
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
101+
(queries.shape[0], k), index.MAX_UINT64
102+
)
77103

78104
assert queries.dtype == np.float32
79105

80106
queries_m = array_to_matrix(np.transpose(queries))
81107
d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)
82108

83109
return np.transpose(np.array(d)), np.transpose(np.array(i))
110+
111+
112+
def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    group_exists: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    **kwargs,
) -> FlatIndex:
    """
    Create an empty flat index at ``uri`` and return it opened.

    Writes the index metadata through ``index.create_metadata``, then creates
    the ids array and the vector ("parts") array inside the group and
    registers both as group members.

    Parameters
    ----------
    uri: str
        URI of the index group.
    dimensions: int
        Number of dimensions of each vector; sets the row domain of the
        vector array.
    vector_type: np.dtype
        Datatype of the stored vector values.
    group_exists: bool
        Forwarded unchanged to ``index.create_metadata``.
        NOTE(review): presumably signals that the group at ``uri`` already
        exists -- confirm against ``index.create_metadata``.
    config: Optional[Mapping[str, Any]]
        TileDB config used for the creation context and for opening the
        returned index.
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    FlatIndex
        The newly created index, opened from ``uri``.
    """
    index.create_metadata(
        uri=uri,
        dimensions=dimensions,
        vector_type=vector_type,
        index_type=INDEX_TYPE,
        group_exists=group_exists,
        config=config,
    )
    with tiledb.scope_ctx(ctx_or_config=config):
        group = tiledb.Group(uri, "w")
        # Close the write handle even when array creation fails, so an error
        # does not leave the group open.
        try:
            # Size tiles so that one tile of vectors is about TILE_SIZE_BYTES.
            # The dimensions below are int32, so pass an integral extent
            # instead of the float that true division produces.
            tile_size = int(
                TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
            )
            attr_filters = storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"]
            ids_array_name = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
            parts_array_name = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
            ids_uri = f"{uri}/{ids_array_name}"
            parts_uri = f"{uri}/{parts_array_name}"

            # External ids: dense 1-D uint64 array with one cell per vector.
            ids_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, MAX_INT32),
                tile=tile_size,
                dtype=np.dtype(np.int32),
            )
            ids_array_dom = tiledb.Domain(ids_array_rows_dim)
            ids_attr = tiledb.Attr(
                name="values",
                dtype=np.dtype(np.uint64),
                filters=attr_filters,
            )
            ids_schema = tiledb.ArraySchema(
                domain=ids_array_dom,
                sparse=False,
                attrs=[ids_attr],
                cell_order="col-major",
                tile_order="col-major",
            )
            tiledb.Array.create(ids_uri, ids_schema)
            group.add(ids_uri, name=ids_array_name)

            # Vectors: dense 2-D array, rows span the vector dimensions and
            # cols index the vectors.
            parts_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
                tile=dimensions,
                dtype=np.dtype(np.int32),
            )
            parts_array_cols_dim = tiledb.Dim(
                name="cols",
                domain=(0, MAX_INT32),
                tile=tile_size,
                dtype=np.dtype(np.int32),
            )
            parts_array_dom = tiledb.Domain(
                parts_array_rows_dim, parts_array_cols_dim
            )
            parts_attr = tiledb.Attr(
                name="values",
                dtype=vector_type,
                filters=attr_filters,
            )
            parts_schema = tiledb.ArraySchema(
                domain=parts_array_dom,
                sparse=False,
                attrs=[parts_attr],
                cell_order="col-major",
                tile_order="col-major",
            )
            tiledb.Array.create(parts_uri, parts_schema)
            group.add(parts_uri, name=parts_array_name)
        finally:
            group.close()
        return FlatIndex(uri=uri, config=config)

0 commit comments

Comments
 (0)