55
66import numpy as np
77from tiledb .vector_search .module import *
8+ from tiledb .vector_search .storage_formats import storage_formats
89from tiledb .cloud .dag import Mode
910from typing import Any , Mapping
1011
11- CENTROIDS_ARRAY_NAME = "centroids.tdb"
12- INDEX_ARRAY_NAME = "index.tdb"
13- IDS_ARRAY_NAME = "ids.tdb"
14- PARTS_ARRAY_NAME = "parts.tdb"
15-
1612
1713def submit_local (d , func , * args , ** kwargs ):
1814 # Drop kwarg
@@ -22,7 +18,7 @@ def submit_local(d, func, *args, **kwargs):
2218
2319
2420class Index :
25- def query (self , targets : np .ndarray , k = 10 , nqueries = 10 , nthreads = 8 , nprobe = 1 ):
21+ def query (self , targets : np .ndarray , k ):
2622 raise NotImplementedError
2723
2824
@@ -36,40 +32,40 @@ class FlatIndex(Index):
3632 URI of dataset
3733 dtype: numpy.dtype
3834 datatype float32 or uint8
39- parts_name: str
40- Optional name of partitions
4135 """
4236
4337 def __init__ (
4438 self ,
4539 uri : str ,
46- dtype : Optional [np .dtype ] = None ,
47- parts_name : str = "parts.tdb" ,
4840 config : Optional [Mapping [str , Any ]] = None ,
4941 ):
5042 # If the user passes a tiledb python Config object convert to a dictionary
5143 if isinstance (config , tiledb .Config ):
5244 config = dict (config )
5345
5446 self .uri = uri
55- self .dtype = dtype
5647 self ._index = None
5748 self .ctx = Ctx (config )
5849 self .config = config
50+ group = tiledb .Group (uri , ctx = tiledb .Ctx (config ))
51+ self .storage_version = group .meta .get ("storage_version" , "0.1" )
52+ self ._db = load_as_matrix (
53+ group [storage_formats [self .storage_version ]["PARTS_ARRAY_NAME" ]].uri ,
54+ ctx = self .ctx ,
55+ config = config ,
56+ )
5957
60- self ._db = load_as_matrix (os .path .join (uri , parts_name ), ctx = self .ctx , config = config )
61-
58+ dtype = group .meta .get ("dtype" , None )
6259 if dtype is None :
6360 self .dtype = self ._db .dtype
6461 else :
65- self .dtype = dtype
62+ self .dtype = np . dtype ( dtype )
6663
6764 def query (
6865 self ,
6966 targets : np .ndarray ,
7067 k : int = 10 ,
7168 nthreads : int = 8 ,
72- nprobe : int = 1 ,
7369 query_type = "heap" ,
7470 ):
7571 """
@@ -84,9 +80,7 @@ def query(
8480 nqueries: int
8581 Number of queries
8682 nthreads: int
87- Number of threads to use for queyr
88- nprobe: int
89- number of probes
83+ Number of threads to use for query
9084 """
9185 # TODO:
9286 # - typecheck targets
@@ -123,7 +117,6 @@ class IVFFlatIndex(Index):
123117 def __init__ (
124118 self ,
125119 uri ,
126- dtype : np .dtype = None ,
127120 memory_budget : int = - 1 ,
128121 config : Optional [Mapping [str , Any ]] = None ,
129122 ):
@@ -134,31 +127,48 @@ def __init__(
134127 self .config = config
135128 self .ctx = Ctx (config )
136129 group = tiledb .Group (uri , ctx = tiledb .Ctx (config ))
137- self .parts_db_uri = group [PARTS_ARRAY_NAME ].uri
138- self .centroids_uri = group [CENTROIDS_ARRAY_NAME ].uri
139- self .index_uri = group [INDEX_ARRAY_NAME ].uri
140- self .ids_uri = group [IDS_ARRAY_NAME ].uri
130+ self .storage_version = group .meta .get ("storage_version" , "0.1" )
131+ self .parts_db_uri = group [
132+ storage_formats [self .storage_version ]["PARTS_ARRAY_NAME" ]
133+ ].uri
134+ self .centroids_uri = group [
135+ storage_formats [self .storage_version ]["CENTROIDS_ARRAY_NAME" ]
136+ ].uri
137+ self .index_uri = group [
138+ storage_formats [self .storage_version ]["INDEX_ARRAY_NAME" ]
139+ ].uri
140+ self .ids_uri = group [
141+ storage_formats [self .storage_version ]["IDS_ARRAY_NAME" ]
142+ ].uri
141143 self .memory_budget = memory_budget
142144
145+ self ._centroids = load_as_matrix (
146+ self .centroids_uri , ctx = self .ctx , config = config
147+ )
148+ self ._index = read_vector_u64 (self .ctx , self .index_uri )
149+
143150 # TODO pass in a context
144151 if self .memory_budget == - 1 :
145152 self ._db = load_as_matrix (self .parts_db_uri , ctx = self .ctx , config = config )
146153 self ._ids = read_vector_u64 (self .ctx , self .ids_uri )
147154
148- self ._centroids = load_as_matrix (self .centroids_uri , ctx = self .ctx , config = config )
149-
150- # TODO this should always be available
155+ dtype = group .meta .get ("dtype" , None )
151156 if dtype is None :
152- self .dtype = self ._centroids .dtype
157+ schema = tiledb .ArraySchema .load (self .parts_db_uri )
158+ self .dtype = np .dtype (schema .attr ("values" ).dtype )
153159 else :
154- self .dtype = dtype
155- self ._index = read_vector_u64 (self .ctx , self .index_uri )
160+ self .dtype = np .dtype (dtype )
161+
162+ self .partitions = group .meta .get ("partitions" , - 1 )
163+ if self .partitions == - 1 :
164+ schema = tiledb .ArraySchema .load (self .centroids_uri )
165+ self .partitions = schema .domain .dim ("cols" ).domain [1 ] + 1
156166
157167 def query (
158168 self ,
159169 queries : np .ndarray ,
160170 k : int = 10 ,
161- nprobe : int = 10 ,
171+ nprobe : int = 1 ,
162172 nthreads : int = - 1 ,
163173 use_nuv_implementation : bool = False ,
164174 mode : Mode = None ,
@@ -198,6 +208,8 @@ def query(
198208
199209 if nthreads == - 1 :
200210 nthreads = multiprocessing .cpu_count ()
211+
212+ nprobe = min (nprobe , self .partitions )
201213 if mode is None :
202214 queries_m = array_to_matrix (np .transpose (queries ))
203215 if self .memory_budget == - 1 :
@@ -313,7 +325,7 @@ def dist_qv_udf(
313325 active_queries = active_queries ,
314326 indices = indices ,
315327 k_nn = k_nn ,
316- ctx = Ctx (config )
328+ ctx = Ctx (config ),
317329 )
318330 results = []
319331 for q in range (len (r )):
@@ -377,9 +389,7 @@ def dist_qv_udf(
377389 ids_uri = self .ids_uri ,
378390 query_vectors = queries ,
379391 active_partitions = np .array (active_partitions )[part :part_end ],
380- active_queries = np .array (
381- aq , dtype = object
382- ),
392+ active_queries = np .array (aq , dtype = object ),
383393 indices = np .array (self ._index ),
384394 k_nn = k ,
385395 config = config ,
@@ -406,5 +416,5 @@ def dist_qv_udf(
406416 tmp = sorted (tmp_results , key = lambda t : t [0 ])[0 :k ]
407417 for j in range (len (tmp ), k ):
408418 tmp .append ((float (0.0 ), int (0 )))
409- results_per_query .append (np .array (tmp , dtype = np .dtype (' float,int' ))['f1' ])
419+ results_per_query .append (np .array (tmp , dtype = np .dtype (" float,int" ))["f1" ])
410420 return results_per_query
0 commit comments