Skip to content

Commit 9d3ee02

Browse files
author
Nikos Papailiou
committed
Fix
1 parent bd2357c commit 9d3ee02

File tree

6 files changed

+43
-35
lines changed

6 files changed

+43
-35
lines changed

.github/workflows/quarto-render.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
- name: "Quarto render"
3737
shell: bash
3838
run: |
39-
pip install quartodoc PyYAML click
39+
pip install quartodoc PyYAML click griffe==0.32.3
4040
# create a symlink to the tiledbvcf python package, so it doesn't have to be installed
4141
#ln -s apis/python/src/tiledb/vector_search
4242
quartodoc build

_quarto.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ quartodoc:
3939
- index.Index
4040
- flat_index.FlatIndex
4141
- ivf_flat_index.IVFFlatIndex
42-
- ingestion.ingest
42+
- ingestion
4343

4444
website:
4545
favicon: "documentation/assets/tiledb.ico"

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class FlatIndex(Index):
1313
Parameters
1414
----------
1515
uri: str
16-
URI of datataset
17-
config: None
16+
URI of the index
17+
config: Optional[Mapping[str, Any]]
1818
config dictionary, defaults to None
1919
"""
2020

@@ -52,7 +52,7 @@ def __init__(
5252

5353
def query_internal(
5454
self,
55-
targets: np.ndarray,
55+
queries: np.ndarray,
5656
k: int = 10,
5757
nthreads: int = 8,
5858
):
@@ -61,22 +61,20 @@ def query_internal(
6161
6262
Parameters
6363
----------
64-
targets: numpy.ndarray
65-
ND Array of query targets
64+
queries: numpy.ndarray
65+
ND Array of queries
6666
k: int
67-
Number of top results to return per target
68-
nqueries: int
69-
Number of queries
67+
Number of top results to return per query
7068
nthreads: int
7169
Number of threads to use for query
7270
"""
7371
# TODO:
74-
# - typecheck targets
72+
# - typecheck queries
7573
# - add all the options and query strategies
7674

77-
assert targets.dtype == np.float32
75+
assert queries.dtype == np.float32
7876

79-
targets_m = array_to_matrix(np.transpose(targets))
80-
r = query_vq_heap(self._db, targets_m, self._ids, k, nthreads)
77+
queries_m = array_to_matrix(np.transpose(queries))
78+
r = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)
8179

8280
return np.transpose(np.array(r))

apis/python/src/tiledb/vector_search/index.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@
88

99

1010
class Index:
11+
"""
12+
Open a Vector index
13+
14+
Parameters
15+
----------
16+
uri: str
17+
URI of the index
18+
config: Optional[Mapping[str, Any]]
19+
config dictionary, defaults to None
20+
"""
1121
def __init__(
1222
self,
13-
uri,
23+
uri: str,
1424
config: Optional[Mapping[str, Any]] = None,
1525
):
1626
# If the user passes a tiledb python Config object convert to a dictionary
@@ -26,14 +36,14 @@ def __init__(
2636
self.index_version = self.group.meta.get("index_version", "")
2737

2838

29-
def query(self, targets: np.ndarray, k, **kwargs):
39+
def query(self, queries: np.ndarray, k, **kwargs):
3040
# TODO merge results based on scores and use higher k to improve retrieval
3141
updated_ids = set(self.read_updated_ids())
32-
internal_results = self.query_internal(targets, k, **kwargs)
42+
internal_results = self.query_internal(queries, k, **kwargs)
3343
if self.update_arrays_uri is None:
3444
return internal_results
35-
addition_results = self.query_additions(targets, k)
36-
merged_results = np.zeros((targets.shape[0], k), dtype=np.uint64)
45+
addition_results = self.query_additions(queries, k)
46+
merged_results = np.zeros((queries.shape[0], k), dtype=np.uint64)
3747
query_id = 0
3848
for query in internal_results:
3949
res_id = 0
@@ -48,17 +58,17 @@ def query(self, targets: np.ndarray, k, **kwargs):
4858
query_id += 1
4959
return merged_results
5060

51-
def query_internal(self, targets: np.ndarray, k, **kwargs):
61+
def query_internal(self, queries: np.ndarray, k, **kwargs):
5262
raise NotImplementedError
5363

54-
def query_additions(self, targets: np.ndarray, k):
55-
assert targets.dtype == np.float32
64+
def query_additions(self, queries: np.ndarray, k):
65+
assert queries.dtype == np.float32
5666

5767
additions_vectors, additions_external_ids = self.read_additions()
58-
targets_m = array_to_matrix(np.transpose(targets))
68+
queries_m = array_to_matrix(np.transpose(queries))
5969
r = query_vq_heap_pyarray(
6070
array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
61-
targets_m,
71+
queries_m,
6272
StdVector_u64(additions_external_ids),
6373
k,
6474
8)

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from functools import partial
33

44
from tiledb.cloud.dag import Mode
5-
from tiledb.vector_search.index import Index
6-
from tiledb.vector_search.flat_index import FlatIndex
7-
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
85
from tiledb.vector_search._tiledbvspy import *
96
import numpy as np
107

@@ -31,7 +28,7 @@ def ingest(
3128
verbose: bool = False,
3229
trace_id: Optional[str] = None,
3330
mode: Mode = Mode.LOCAL,
34-
) -> Index:
31+
):
3532
"""
3633
Ingest vectors into TileDB.
3734
@@ -100,6 +97,9 @@ def ingest(
10097
from tiledb.cloud.utilities import get_logger
10198
from tiledb.cloud.utilities import set_aws_context
10299
from tiledb.vector_search.storage_formats import storage_formats, STORAGE_VERSION
100+
from tiledb.vector_search.index import Index
101+
from tiledb.vector_search.flat_index import FlatIndex
102+
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
103103

104104
# use index_group_uri for internal clarity
105105
index_group_uri = index_uri

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ class IVFFlatIndex(Index):
2222
Parameters
2323
----------
2424
uri: str
25-
URI of datataset
26-
config: None
25+
URI of the index
26+
config: Optional[Mapping[str, Any]]
2727
config dictionary, defaults to None
2828
memory_budget: int
2929
Main memory budget. If not provided, no memory budget is applied.
3030
"""
3131

3232
def __init__(
3333
self,
34-
uri,
34+
uri: str,
3535
config: Optional[Mapping[str, Any]] = None,
3636
memory_budget: int = -1,
3737
):
@@ -100,7 +100,7 @@ def query_internal(
100100
queries: numpy.ndarray
101101
ND Array of queries
102102
k: int
103-
Number of top results to return per target
103+
Number of top results to return per query
104104
nprobe: int
105105
number of probes
106106
nthreads: int
@@ -193,13 +193,11 @@ def taskgraph_query(
193193
queries: numpy.ndarray
194194
ND Array of queries
195195
k: int
196-
Number of top results to return per target
196+
Number of top results to return per query
197197
nprobe: int
198198
number of probes
199199
nthreads: int
200200
Number of threads to use for query
201-
use_nuv_implementation: bool
202-
wether to use the nuv query implementation. Default: False
203201
mode: Mode
204202
If provided the query will be executed using TileDB cloud taskgraphs.
205203
For distributed execution you can use REALTIME or BATCH mode
@@ -209,6 +207,8 @@ def taskgraph_query(
209207
num_workers: int
210208
Only relevant for taskgraph based execution.
211209
If provided, this is the number of workers to use for the query execution.
210+
config: None
211+
config dictionary, defaults to None
212212
"""
213213
from tiledb.cloud import dag
214214
from tiledb.cloud.dag import Mode

0 commit comments

Comments
 (0)