Skip to content

Commit cae0cab

Browse files
Exposes all versions of query functions to Python (#79)
This exposes all versions of query functions to Python. Also did some minor tweeks for ingestion and ran black to format the code.
1 parent ebcaaef commit cae0cab

File tree

11 files changed

+461
-130
lines changed

11 files changed

+461
-130
lines changed

apis/python/setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55

66
def get_cmake_overrides():
7-
87
conf = list()
98

109
tiledb_dir = os.environ.get("TILEDB_DIR", None)

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22
from .ingestion import ingest
33
from .module import load_as_array
44
from .module import load_as_matrix
5-
from .module import query_vq, query_kmeans, validate_top_k, array_to_matrix, ivf_index, ivf_index_tdb, partition_ivf_index
5+
from .module import (
6+
query_vq,
7+
ivf_query,
8+
ivf_query_ram,
9+
validate_top_k,
10+
array_to_matrix,
11+
ivf_index,
12+
ivf_index_tdb,
13+
partition_ivf_index,
14+
)
615

716
__all__ = [
817
"FlatIndex",
@@ -11,10 +20,11 @@
1120
"load_as_matrix",
1221
"ingest",
1322
"query_vq",
14-
"query_kmeans",
23+
"ivf_query",
24+
"ivf_query_ram",
1525
"validate_top_k",
1626
"ivf_index",
1727
"ivf_index_tdb",
1828
"array_to_matrix",
19-
"partition_ivf_index"
29+
"partition_ivf_index",
2030
]

apis/python/src/tiledb/vector_search/index.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,40 @@ class IVFFlatIndex(Index):
7676
URI of datataset
7777
dtype: numpy.dtype
7878
datatype float32 or uint8
79+
memory_budget: int
80+
Main memory budget. If not provided no memory budget is applied.
7981
"""
8082

81-
def __init__(self, uri, dtype: np.dtype):
83+
def __init__(
84+
self, uri, dtype: np.dtype, memory_budget: int = -1, ctx: "Ctx" = None
85+
):
8286
self.parts_db_uri = os.path.join(uri, "parts.tdb")
8387
self.centroids_uri = os.path.join(uri, "centroids.tdb")
8488
self.index_uri = os.path.join(uri, "index.tdb")
8589
self.ids_uri = os.path.join(uri, "ids.tdb")
8690
self.dtype = dtype
91+
self.memory_budget = memory_budget
92+
self.ctx = ctx
93+
if ctx is None:
94+
self.ctx = Ctx({})
95+
96+
# TODO pass in a context
97+
if self.memory_budget == -1:
98+
self._db = load_as_matrix(self.parts_db_uri)
99+
self._ids = read_vector_u64(self.ctx, self.ids_uri)
87100

88-
ctx = Ctx({}) # TODO pass in a context
89-
self._db = load_as_matrix(self.parts_db_uri)
90101
self._centroids = load_as_matrix(self.centroids_uri)
91-
self._index = read_vector_u64(ctx, self.index_uri)
92-
self._ids = read_vector_u64(ctx, self.ids_uri)
102+
self._index = read_vector_u64(self.ctx, self.index_uri)
93103

94-
def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
104+
def query(
105+
self,
106+
targets: np.ndarray,
107+
k=10,
108+
nqueries=10,
109+
nthreads=8,
110+
nprobe=1,
111+
use_nuv_implementation: bool = False,
112+
):
95113
"""
96114
Open a flat index
97115
@@ -107,21 +125,42 @@ def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
107125
Number of threads to use for queyr
108126
nprobe: int
109127
number of probes
128+
use_nuv_implementation: bool
129+
wether to use the nuv query implementation. Default: False
110130
"""
111131
assert targets.dtype == np.float32
112132

113133
targets_m = array_to_matrix(targets)
134+
if self.memory_budget == -1:
135+
r = ivf_query_ram(
136+
self.dtype,
137+
self._db,
138+
self._centroids,
139+
targets_m,
140+
self._index,
141+
self._ids,
142+
nprobe=nprobe,
143+
k_nn=k,
144+
nth=True, # ??
145+
nthreads=nthreads,
146+
ctx=self.ctx,
147+
use_nuv_implementation=use_nuv_implementation,
148+
)
149+
else:
150+
r = ivf_query(
151+
self.dtype,
152+
self.parts_db_uri,
153+
self._centroids,
154+
targets_m,
155+
self._index,
156+
self.ids_uri,
157+
nprobe=nprobe,
158+
k_nn=k,
159+
memory_budget=self.memory_budget,
160+
nth=True, # ??
161+
nthreads=nthreads,
162+
ctx=self.ctx,
163+
use_nuv_implementation=use_nuv_implementation,
164+
)
114165

115-
r = query_kmeans(
116-
self._db.dtype,
117-
self._db,
118-
self._centroids,
119-
targets_m,
120-
self._index,
121-
self._ids,
122-
nprobe=nprobe,
123-
k_nn=k,
124-
nth=True, # ??
125-
nthreads=nthreads,
126-
)
127166
return np.array(r)

0 commit comments

Comments
 (0)